augmented interpretable models
v0.2
تعزيز النماذج القابلة للتفسير باستخدام نماذج اللغة الكبيرة (LLMs) أثناء التدريب
يحتوي هذا المستودع على الشيفرة البرمجية اللازمة لإعادة إنتاج التجارب الواردة في ورقة Aug-imodels (Nature Communications، 2023). للحصول على واجهة بسيطة على غرار scikit-learn لاستخدام Aug-imodels، استخدم مكتبة imodelsx. فيما يلي مثال للبدء السريع.
التثبيت: pip install imodelsx
from imodelsx import AugLinearClassifier , AugTreeClassifier , AugLinearRegressor , AugTreeRegressor
import datasets
import numpy as np
# set up data
dset = datasets . load_dataset ( 'rotten_tomatoes' )[ 'train' ]
dset = dset . select ( np . random . choice ( len ( dset ), size = 300 , replace = False ))
dset_val = datasets . load_dataset ( 'rotten_tomatoes' )[ 'validation' ]
dset_val = dset_val . select ( np . random . choice ( len ( dset_val ), size = 300 , replace = False ))
# fit model
m = AugLinearClassifier (
checkpoint = 'textattack/distilbert-base-uncased-rotten-tomatoes' ,
ngrams = 2 , # use bigrams
)
m . fit ( dset [ 'text' ], dset [ 'label' ])
# predict
preds = m . predict ( dset_val [ 'text' ])
print ( 'acc_val' , np . mean ( preds == dset_val [ 'label' ]))
# interpret
print ( 'Total ngram coefficients: ' , len ( m . coefs_dict_ ))
print ( 'Most positive ngrams' )
for k , v in sorted ( m . coefs_dict_ . items (), key = lambda item : item [ 1 ], reverse = True )[: 8 ]:
print ( ' t ' , k , round ( v , 2 ))
print ( 'Most negative ngrams' )
for k , v in sorted ( m . coefs_dict_ . items (), key = lambda item : item [ 1 ])[: 8 ]:
print ( ' t ' , k , round ( v , 2 ))مرجع:
@misc{ch2022augmenting,
    title={Augmenting Interpretable Models with LLMs during Training},
    author={Chandan Singh and Armin Askari and Rich Caruana and Jianfeng Gao},
    year={2022},
    eprint={2209.11799},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}