augmented interpretable models
v0.2
在訓練期間使用 LLM 增強可解釋模型
本儲存庫包含用於重現 Aug-imodels 論文(Nature Communications,2023 年)中實驗的代碼。若要透過簡單的 scikit-learn 風格接口使用 Aug-imodels,請使用 imodelsx 庫。以下是一個快速入門示例。
安裝: pip install imodelsx
from imodelsx import AugLinearClassifier , AugTreeClassifier , AugLinearRegressor , AugTreeRegressor
import datasets
import numpy as np
# set up data
dset = datasets . load_dataset ( 'rotten_tomatoes' )[ 'train' ]
dset = dset . select ( np . random . choice ( len ( dset ), size = 300 , replace = False ))
dset_val = datasets . load_dataset ( 'rotten_tomatoes' )[ 'validation' ]
dset_val = dset_val . select ( np . random . choice ( len ( dset_val ), size = 300 , replace = False ))
# fit model
m = AugLinearClassifier (
checkpoint = 'textattack/distilbert-base-uncased-rotten-tomatoes' ,
ngrams = 2 , # use bigrams
)
m . fit ( dset [ 'text' ], dset [ 'label' ])
# predict
preds = m . predict ( dset_val [ 'text' ])
print ( 'acc_val' , np . mean ( preds == dset_val [ 'label' ]))
# interpret
print ( 'Total ngram coefficients: ' , len ( m . coefs_dict_ ))
print ( 'Most positive ngrams' )
for k , v in sorted ( m . coefs_dict_ . items (), key = lambda item : item [ 1 ], reverse = True )[: 8 ]:
print ( ' t ' , k , round ( v , 2 ))
print ( 'Most negative ngrams' )
for k , v in sorted ( m . coefs_dict_ . items (), key = lambda item : item [ 1 ])[: 8 ]:
print ( ' t ' , k , round ( v , 2 ))參考:
@misc{ch2022augmenting,
    title={Augmenting Interpretable Models with LLMs during Training},
    author={Chandan Singh and Armin Askari and Rich Caruana and Jianfeng Gao},
    year={2022},
    eprint={2209.11799},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}