tree prompt
1.0.0
树提示:有效的任务改编而无需微调,为树纸纸代码。
树提示使用培训示例学习提示树以进行分类,从而提高了基线集合的精度和提高效率。
安装: pip install treeprompt (或克隆此回购和pip install -e . )。
from treeprompt . treeprompt import TreePromptClassifier
import datasets
import numpy as np
from sklearn . tree import plot_tree
import matplotlib . pyplot as plt
# set up data
rng = np . random . default_rng ( seed = 42 )
dset_train = datasets . load_dataset ( 'rotten_tomatoes' )[ 'train' ]
dset_train = dset_train . select ( rng . choice (
len ( dset_train ), size = 100 , replace = False ))
dset_val = datasets . load_dataset ( 'rotten_tomatoes' )[ 'validation' ]
dset_val = dset_val . select ( rng . choice (
len ( dset_val ), size = 100 , replace = False ))
# set up arguments
prompts = [
"This movie is" ,
" Positive or Negative? The movie was" ,
" The sentiment of the movie was" ,
" The plot of the movie was really" ,
" The acting in the movie was" ,
]
verbalizer = { 0 : " Negative." , 1 : " Positive." }
checkpoint = "gpt2"
# fit model
m = TreePromptClassifier (
checkpoint = checkpoint ,
prompts = prompts ,
verbalizer = verbalizer ,
cache_prompt_features_dir = None , # 'cache_prompt_features_dir/gp2',
)
m . fit ( dset_train [ "text" ], dset_train [ "label" ])
# compute accuracy
preds = m . predict ( dset_val [ 'text' ])
print ( ' n Tree-Prompt acc (val) ->' ,
np . mean ( preds == dset_val [ 'label' ])) # -> 0.7
# compare to accuracy for individual prompts
for i , prompt in enumerate ( prompts ):
print ( i , prompt , '->' , m . prompt_accs_ [ i ]) # -> 0.65, 0.5, 0.5, 0.56, 0.51
# visualize decision tree
plot_tree (
m . clf_ ,
fontsize = 10 ,
feature_names = m . feature_names_ ,
class_names = list ( verbalizer . values ()),
filled = True ,
)
plt . show ()参考:
@ misc { morris2023tree ,
title = { Tree Prompting : Efficient Task Adaptation without Fine - Tuning },
author = { John X. Morris and Chandan Singh and Alexander M. Rush and Jianfeng Gao and Yuntian Deng },
year = { 2023 },
eprint = { 2310.14034 },
archivePrefix = { arXiv },
primaryClass = { cs.CL }
}请参阅https://github.com/csinva/tree-prompt-experiments the Plape中的所有实验的完整代码