tree prompt
1.0.0
ツリープロンプト:微調整なしの効率的なタスク適応、ツリープロンプペーパーのコード。
ツリープロンプトは、トレーニングの例を使用してプロンプトのツリーを学習して分類を行い、ベースラインアンサンブルよりも高い精度とより良い効率をもたらします。
インストール: pip install treeprompt (または、このレポとpip install -e .クローンします)
from treeprompt . treeprompt import TreePromptClassifier
import datasets
import numpy as np
from sklearn . tree import plot_tree
import matplotlib . pyplot as plt
# set up data
rng = np . random . default_rng ( seed = 42 )
dset_train = datasets . load_dataset ( 'rotten_tomatoes' )[ 'train' ]
dset_train = dset_train . select ( rng . choice (
len ( dset_train ), size = 100 , replace = False ))
dset_val = datasets . load_dataset ( 'rotten_tomatoes' )[ 'validation' ]
dset_val = dset_val . select ( rng . choice (
len ( dset_val ), size = 100 , replace = False ))
# set up arguments
prompts = [
"This movie is" ,
" Positive or Negative? The movie was" ,
" The sentiment of the movie was" ,
" The plot of the movie was really" ,
" The acting in the movie was" ,
]
verbalizer = { 0 : " Negative." , 1 : " Positive." }
checkpoint = "gpt2"
# fit model
m = TreePromptClassifier (
checkpoint = checkpoint ,
prompts = prompts ,
verbalizer = verbalizer ,
cache_prompt_features_dir = None , # 'cache_prompt_features_dir/gp2',
)
m . fit ( dset_train [ "text" ], dset_train [ "label" ])
# compute accuracy
preds = m . predict ( dset_val [ 'text' ])
print ( ' n Tree-Prompt acc (val) ->' ,
np . mean ( preds == dset_val [ 'label' ])) # -> 0.7
# compare to accuracy for individual prompts
for i , prompt in enumerate ( prompts ):
print ( i , prompt , '->' , m . prompt_accs_ [ i ]) # -> 0.65, 0.5, 0.5, 0.56, 0.51
# visualize decision tree
plot_tree (
m . clf_ ,
fontsize = 10 ,
feature_names = m . feature_names_ ,
class_names = list ( verbalizer . values ()),
filled = True ,
)
plt . show ()参照:
@ misc { morris2023tree ,
title = { Tree Prompting : Efficient Task Adaptation without Fine - Tuning },
author = { John X. Morris and Chandan Singh and Alexander M. Rush and Jianfeng Gao and Yuntian Deng },
year = { 2023 },
eprint = { 2310.14034 },
archivePrefix = { arXiv },
primaryClass = { cs.CL }
}https://github.com/csinva/tree-prompt-experimentsの論文のすべての実験を再現するための完全なコードを参照してください