augmented interpretable models
v0.2
Augmenting interpretable models with LLMs during training
This repo contains code to reproduce the experiments in the Aug-imodels paper (Nature Communications, 2023). For a simple scikit-learn interface to Aug-imodels, use the imodelsX library. Below is a quickstart example.
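Roughly, an augmented linear model uses the LLM only during training. Each ngram of a text is embedded with the LLM, the ngram embeddings are summed, and a linear model is fit on that sum. Because the model is linear, it can be rewritten as one coefficient per ngram, and a text's score is just the sum of its ngrams' coefficients. The pure-Python sketch below illustrates that decomposition under explicit assumptions. `embed` is a hypothetical deterministic pseudo-random vector standing in for the LLM embedding. The toy texts are made up. Plain gradient-descent logistic regression stands in for whatever solver imodelsX actually uses. It is an illustration, not the library's implementation.

```python
import hashlib
import math
import random

DIM = 8  # toy embedding size (a real LLM would use its hidden size)


def ngrams(text, n=2):
    """Return all 1..n-grams of a lowercased, whitespace-tokenized text."""
    toks = text.lower().split()
    out = []
    for k in range(1, n + 1):
        out += [" ".join(toks[i:i + k]) for i in range(len(toks) - k + 1)]
    return out


def embed(ngram):
    """HYPOTHETICAL stand-in for an LLM embedding: a fixed pseudo-random vector per ngram."""
    seed = int(hashlib.sha256(ngram.encode()).hexdigest(), 16)
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(DIM)]


def dot(u, v):
    return sum(a * c for a, c in zip(u, v))


def text_features(text, n=2):
    """Sum the embeddings of every ngram in the text, so the model stays linear in ngrams."""
    feats = [0.0] * DIM
    for g in ngrams(text, n):
        feats = [a + c for a, c in zip(feats, embed(g))]
    return feats


def sigmoid(z):
    # numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_logistic(X, y, lr=0.05, epochs=500):
    """Batch gradient descent for logistic regression (a simple substitute solver)."""
    w, b = [0.0] * len(X[0]), 0.0
    for _ in range(epochs):
        gw, gb = [0.0] * len(w), 0.0
        for xi, yi in zip(X, y):
            err = sigmoid(dot(w, xi) + b) - yi
            gw = [g + err * c for g, c in zip(gw, xi)]
            gb += err
        w = [a - lr * g / len(X) for a, g in zip(w, gw)]
        b -= lr * gb / len(X)
    return w, b


# Made-up toy data, for illustration only.
texts = ["a great movie", "truly great fun", "a boring mess", "truly boring plot"]
labels = [1, 1, 0, 0]
w, b = fit_logistic([text_features(t) for t in texts], labels)

# One coefficient per ngram: the dot product of its embedding with the learned weights.
coefs = {g: dot(w, embed(g)) for t in texts for g in ngrams(t)}

# The logit of a text equals the intercept plus the sum of its ngram coefficients.
t = "a great movie"
logit_full = dot(w, text_features(t)) + b
logit_ngrams = sum(coefs[g] for g in ngrams(t)) + b
assert math.isclose(logit_full, logit_ngrams, rel_tol=1e-9, abs_tol=1e-9)

# Inspect the largest ngram coefficients, in the spirit of the quickstart's interpret step.
for g, c in sorted(coefs.items(), key=lambda item: item[1], reverse=True)[:3]:
    print(g, round(c, 2))
```

Once the coefficient table is computed, the model is fully transparent. The quickstart reads a table of this kind through `coefs_dict_`, although details such as regularization or the handling of ngrams unseen during training may differ in imodelsX.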
Installation: pip install imodelsx
from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data (random subsample of 300 examples from each split)
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,  # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict and report validation accuracy
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret: each ngram has its own coefficient
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))

References:
@misc{ch2022augmenting,
  title = {Augmenting Interpretable Models with LLMs during Training},
  author = {Chandan Singh and Armin Askari and Rich Caruana and Jianfeng Gao},
  year = {2022},
  eprint = {2209.11799},
  archivePrefix = {arXiv},
  primaryClass = {cs.AI}
}