augmented interpretable models
v0.2
Augmenting interpretable models with LLMs during training
This repo contains code to reproduce the experiments in the Aug-imodels paper (Nature Communications, 2023). To use Aug-imodels through a simple scikit-learn interface, use the imodelsX library. A quick example follows.
Installation: pip install imodelsx
from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,  # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))

Reference:
@misc{ch2022augmenting,
    title={Augmenting Interpretable Models with LLMs during Training},
    author={Chandan Singh and Armin Askari and Rich Caruana and Jianfeng Gao},
    year={2022},
    eprint={2209.11799},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}
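
The quick example reads `coefs_dict_` as a mapping from each ngram to a scalar coefficient. As a rough illustration of why such a mapping is easy to interpret, the sketch below scores a text by summing the coefficients of the unigrams and bigrams it contains. This is a toy under stated assumptions, not the imodelsX implementation: the helpers `extract_ngrams` and `score_text`, the whitespace tokenization, and the values in `toy_coefs` are all hypothetical.

```python
def extract_ngrams(text, max_n=2):
    """Return all ngrams of length 1..max_n from a whitespace-tokenized text.

    Whitespace tokenization is an assumption made for this sketch only.
    """
    tokens = text.lower().split()
    ngrams = []
    for n in range(1, max_n + 1):
        for i in range(len(tokens) - n + 1):
            ngrams.append(" ".join(tokens[i:i + n]))
    return ngrams


def score_text(text, coefs, max_n=2):
    """Sum the coefficients of every ngram in the text (unknown ngrams count as 0)."""
    return sum(coefs.get(gram, 0.0) for gram in extract_ngrams(text, max_n))


# Hypothetical coefficients, made up for illustration (not fitted values).
toy_coefs = {"great": 1.5, "boring": -2.0, "not great": -3.0}

for text in ["great", "not great", "boring"]:
    print(text, "->", score_text(text, toy_coefs))
```

Because the score is a plain sum, each ngram's contribution can be read off directly; in this toy the bigram "not great" outweighs the positive unigram "great".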