這是從EMNLP 2021論文“參數效率及時調整的規模的功率”中復制實驗的代碼(Lester等,2021)。
這些模型建立在T5X上,該模型定義了模型和訓練循環。 FlaxFormer,定義實際模型計算;亞麻,定義低級模型層;和JAX,提供實際執行。我們實施的詳細信息可以在此處找到。
gs://{bucket-name}/path/to/item/in/bucket URI讀取數據並將數據寫入此存儲桶。在這裡,我們將存儲緩存的數據集以及模型檢查點和結果。易於參考,與TPU VM互動的一些最常見的雲命令是 # Create a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME}
--zone ${ZONE}
--accelerator-type v3-8
--version v2-alpha
# SSH into a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE}
# Delete a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm delete ${TPU_NAME} --zone ${ZONE}git clone --branch=main https://github.com/google-research/prompt-tuning
cd prompt-tuningpython3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html如果您遇到了一個錯誤,即PIP試圖更早安裝,並且依賴版本的依賴版本(例如TensorFlow),直到嘗試安裝0.0.0版本,然後嘗試添加--use-deprecated=legacy-resolver to install命令。此錯誤與依賴關係的必需版本有關,該行為通常稱為回溯。如果使用該標誌,則可能會安裝不兼容的庫版本,並且應該注意有關安裝命令輸出中不匹配的警告。
注意:如果您打算在及時調整的內部進行攻擊,並且需要安裝可編輯(因此,在運行訓練時使用克隆代碼的更改)將使用-e標誌運行pip ,並且如果您在安裝過程中遇到錯誤,則可能需要刪除pyproject.toml文件。
要運行測試,請使用[test] ( python3 -m pip install .[test] ... )選項安裝軟件包,然後從克隆的存儲庫的根部運行python3 -m pytest 。
訓練提示類似於用T5X微調模型;主要區別在於我們有自己的一組及時調整配置文件。
我們提供了一個演示腳本( prompt_tuning/scripts/sst2-demo.sh ),該腳本具有訓練提示的所有必需零件。您可以將其用作起點,也可以設置MODEL_DIR和TFDS_DATA_DIR環境變量,該變量具有通往Google Cloud儲物桶的路徑,直接運行此腳本。
./prompt-tuning/prompt_tuning/scripts/sst2-demo.sh為了幫助迭代速度,我們傾向於指定更多選項命令行,而不是將所有配置捆綁到一個杜松子酒文件中。值得注意的選擇:
--gin_search_paths ::一個逗號分隔的目錄列表,用於用作杜松子酒文件的路徑前綴。我們可以使用prompt_tuning.scripts.find_module ${module}查找與它們捆綁配置的庫的安裝位置。--gin_file ::要加載的GIN文件。我們傾向於使用相對的路徑,從安裝的庫開始,即prompt_tuning/configs/models/t5_1_1_base_prompt.gin models/t5_1_1_base_prompt.gin gin,以避免任何混亂。使用標誌多個時間可用於指定將合併在一起的多個杜松子酒文件。在多個文件中設置的任何配置選項都將使用列表中的最後一個文件中的值。--gin.{PARAM}={VALUE} ::這個常規覆蓋標誌將將PARAM設置為VALUE 。這可以用於輕鬆設置配置選項,而無需它們是實際的命令行參數。例如。 --gin.utils.SaveCheckpointConfig.keep=20將保存最後20個檢查點。例如,隨著型號變大,XL和XXL例如,它們不適合單個TPU VM的8個TPU。在這些情況下,我們將需要一個TPU POD(有關TPU體系結構和可用配置的更多信息,請參見此處)。在單個TPU VM上訓練提示和POD切片上的提示之間的主要區別在於,我們現在有多個TPU VM,並且每個VM都將運行相同的SPMD JAX,此頁面具有有關多主機JAX程序的更多信息。本指南可以快速介紹在TPU POD片上運行JAX程序,但我們將在這里達到主要要點。
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME}
--zone ${ZONE}
--accelerator-type v3-32
--version v2-alpha--command= flag指定命令,並且應在我們的所有VM(稱為工人)上運行--worker=all 。 $ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " git clone --branch=main https://github.com/google-research/prompt-tuning && cd prompt-tuning && "
python3 -m pip install . -f https://storage.googleapis.com/jax-releases/libtpu_releases.html編寫腳本來訓練您的提示。我們包括了一個演示腳本( /prompt_tuning/scripts/sst2-xxl-demo.sh )火車提示使用T5 1.1 LM100K XXL求解SST2數據集。您可以將其用作起點,也可以填寫到Google Cloud儲物存儲桶的路徑以指定要保存結果的位置( MODEL_DIR )以及在哪裡緩存TFDS數據( TFDS_DATA_DIR ),或將它們設置為環境變量。
複製每個工人的培訓腳本。如果這是您第一次運行scp ,則可能會出現錯誤,請從錯誤消息中運行ssh-add /.../.ssh/google_compute_engine命令,然後重試。
$ gcloud alpha compute tpus tpu-vm scp sst2-xxl-demo.sh ${TPU_NAME} :
--zone= ${ZONE}
--worker=all$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " ./sst2-xxl-demo.sh "如果其中一名工人在培訓期間有錯誤,則將留下其他工人使用TPU的過程。這將阻止您重新啟動工作,直到處理終止並釋放TPU為止。以下命令應結束所有這些過程。您可能會看到kill Command Man頁面從有最初錯誤的工人那裡回來。
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " sudo lsof -t /dev/accel0 | xargs kill -9 "要使用自定義零件(例如您自己的數據集)訓練提示,請按照自定義組件的T5X說明進行操作
如果您將代碼作為PIP存儲python軟件包包裝,則不會綁定到一個單個目錄,並且可以使用python3 -m prompt_tuning.scripts.find_module {your_module}來幫助設置gin_search_paths ,以便在庫中插入GIN配置。注意:如果您計劃將杜松子酒配置包裝在可安裝的軟件包中,請確保包含配置文件的目錄具有__init__.py ,因為杜松子gin需要文件在python軟件包中。
如果您的自定義組件的一部分是杜松子酒可配置的,則需要在杜松子酒文件中明確導入;如果他們最終在解析杜松子酒文件後被導入,則會導致錯誤。如果您的依賴項都不包含Gin Configurables,則可以避免通過傳遞--gin.MIXTURE_OR_TASK_MODULE="'path.to.your.module'編寫杜松子酒文件。這將自動導入您的模塊,並且在您所做的所有操作時都很方便地交換了數據集。
我們建議使用提示進行推斷的建議方法是加載用於初始化模型的原始檢查點以及文件中的提示。如本節所述,有關部分加載T5X的人數支持加載某些模型參數,同時從頭開始初始化其他參數。我們將其與from_array提示初始化器結合使用,以從原始檢查點重新加載冷凍參數,並提示文件一個文件。 configs/runs/prompt_eval.gin為您設置此配置;您只需要提供PROMPT_FILE 。如果您的模型接受了任何prompts/配置文件的培訓,則可以將其從參數中刪除到評估腳本。
隨附的sst2-demo-eval.sh腳本顯示了這樣做評估的示例。所需要做的就是將EVAL_DIR和TFDS_DATA_DIR環境變量設置為以路徑存儲評估輸出和TensorFlow數據集的路徑。
在T5X中,評估腳本假設您的數據集具有標籤,並輸出數據集的度量功能的最終結果。推理腳本不需要標籤,而是輸出模型的預測。我們包括一個類似的prompt_infer.gin文件,用於與推理腳本一起使用。
如果要使用迅速調整訓練運行產生的T5X檢查點進行推理或評估,則可以直接使用T5X的(eval|infer).gin配置。不過,您需要更新utils.RestoreChekcpointConfig 。您應該將path設置為新的檢查點, assignment_map=()和fallback_to_scratch=False 。
所有型號,培訓,評估,節省,還原等。配置都是通過杜松進行的。有關杜松子酒的一般介紹,請參見杜松子酒庫庫庫
我們遵循T5X配置佈局:
runs/ ::包含用於模型實際培訓的配置。這是數據集和評估配置之類的東西。architectures/ ::包含模型工作方式的配置。這是配置諸如Encoder-decoder vs on foldecoder vs-for-decoder和嵌入共享之類的地方。models/ ::包含設置模型特定參數的配置,例如圖層數量或嵌入式表的大小。它還配置了使用T5X型號包裝器之類的東西。models/decoding/ ::包含易於使用的配置來交換推理過程中模型如何生成文本的方式,包括用於梁搜索和核採樣的配置。models/sizes/ ::包含各種設置以創建不同尺寸的模型,這些設置與默認版本相結合以創建尺寸的版本,例如, t5_1_1_prompt.gin + sizes/large.gin創建T5 1.1大型模型。一些常見的組合已作為帶有權利的杜松子酒文件可用的,包括( t5_1_1_large_prompt.gin ,對於上面的示例)。注意:這些大小文件需要在主型號文件之後出現。prompts/ ::我們的額外目錄包含設置PROMPT GIN變量的配置,可以輕鬆切換基於提示初始化的提示文件作為--gin_file參數(需要在models/ GIN文件之後出現),在命令行中指定--gin_file參數時,順序很重要。必須指定杜松子酒文件的一般順序是:
models/*.ginprompts/*.ginmodels/sizes/*.gin*models/decoding/*.ginruns/*.ginT5X具有一些必需的字段,例如MIXTURE_OR_TASK_NAME或TASK_FEATURE_LENGTHS 。我們再增加兩個:
PROMPT_LENGTH ::我們正在使用的提示的長度,它在幾個不同的位置使用,我們要求它作為杜松子酒宏我們可以在多個位置引用並確保值是同步的。PROMPT ::這是將在FlaxFormer PromptX子類中使用的實際提示模塊的配置。注意:提示調整當前不支持示例包裝。這意味著我們的最大目標長度只需要足夠長的時間就可以適合每個示例的目標。這意味著我們的TASK_FEATURE_LENGTHS映射中的targets鍵可能要短得多,例如,對於許多超級lue(Wang等,2019)任務約為4個,而P5X默認值為62。
提示參數的初始化有幾個選項。我們支持第3.2節中的各種方法以及文件初始化。後者允許人們從在MNLI上學習的提示開始,例如在Boolq上進行訓練。
所有初始化器都遵循亞麻初始化器API,它是一個參數化函數,該函數返回閉合初始化函數。實際的初始化功能始終具有
def initializer ( rng : Array , shape : Sequence [ int ]) -> Array :
...我們在configs/prompts目錄中提供每個初始化方案作為杜松子述配置文件。可以通過將gin文件與--gin_file=path/to/configs/prompts/scheme.gin一起使用。此文件需要在主模型文件之後出現,否則默認(隨機統一)方法將覆蓋您選擇的一個方法。這些初始化方法中的某些方法將要求您設置額外的杜松子值值,儘管其中一個杜松子酒文件中的覆蓋標誌。
隨機均勻
標準的隨機初始化類似於人們用來嵌入初始化的初始化。這是默認值,不需要杜松子酒文件。隨機值的比例可以通過覆蓋prompt_init/linen.initializers.uniform.scale=N來調整。
採樣詞彙
樣本一個令牌嵌入以用作每個提示位置的初始化,並使用from_sample_of_embeddings Initializer。您可以使用prompt_init/prompts.from_samples_of_embeddings.population_size將採樣限制為第一個n嵌入。
可以與--gin_file=prompt_tuning/configs/prompts/from_sampled_vocab.gin一起使用。此方法使用從初始模型檢查點提取的嵌入式表。您還可以提供自己的嵌入文件,其中--gin_file=prompt_tuning/configs/prompts/from_sampled_vocab_numpy.gin 。此方法要求您為EMBEDDING_FILE提供一個值,該值是模型嵌入式表的數組。可以使用stript_tuning.scripts.extract_variable從模型檢查點提取這一點。
班級標籤
我們通過from_embedded_list initializer嵌入類標籤(又稱Verbalizers )的嵌入來初始化提示時間段。用戶提供要使用的單詞(類標籤)列表。每個單詞都被提供的詞彙示意。帶有提供的詞彙表;跨亞tokens(如果需要)匯總;並用於初始化及時的時間步長。如果所提供的令牌不涵蓋完整的及時長度,則使用提供的倒退初始化器初始化丟失的令牌。
我們可以匹配紙張,在該紙張中,通過將此初始化與上面的初始化組成,在嵌入式表中填充未填充的提示令牌。它可以與--gin_file=prompt_tuning/configs/prompts/from_class_labels.gin一起使用。這需要設置CLASS_LABELS ,這是您要嵌入為及時初始化的單詞的列表。您還可以使用--gin_file=prompt_tuning/configs/prompts/from_class_labels_numpy.gin提供自己的嵌入文件(與上述相同)。此外,這需要設置EMBEDDING_FILE 。
從字符串
我們還支持通過嵌入一些字符串的提示來初始化提示,通常用於從離散提示或任務描述開始。這使用了from_embedded_string Initializer。該字符串由提供的詞彙進行標記化,每個令牌都在提供的嵌入式表中查找,並將其結果的嵌入式表示形式用作及時的初始化。如果所提供的令牌不涵蓋完整的及時長度,則使用提供的倒退初始化器初始化丟失的令牌。
注意:詞彙僅將字符串轉換為ID序列,您將需要確保字符串匹配您SEQIO任務所做的任何文本格式(標點符號等)的結果。
從文件
您還可以從文件中加載來自from_array Initializer的文件以啟用跨任務的傳輸。這是通過--gin_file=prompt_tuning/configs/prompts/from_file.gin完成的。這需要設置PROMPT_FILE ,並帶有通往numpy文件的路徑,並加載提示。訓練時默認情況下會發出Numpy版本,但是也可以使用上述腳本提取提示。
我們已經發布了具有100K語言模型適應性步驟的T5 1.1檢查點的T5X本機檢查點。
這些是從公共網格Tensorflow檢查點轉換的。
我們已經發布了有關各種任務的預估計提示,併計劃隨著時間的推移將其添加到它們。
提示可以在pretrained_prompts目錄中找到。從那裡,每個子目錄組都按照他們訓練的模型提示。引用與圖書館捆綁的這些提示的最簡單方法是:
--PROMPT_FILE= ` python3 -m prompt_tuning.scripts.find_module prompt_tuning ` /pretrained_prompts/{MODEL_SIZE}/{PROMPT}.npy由於並行計算的固有隨機性,需要在訓練和評估之間匹配一些設置以獲得完全相同的數字。每個模型子目錄都有一個README.md指定這些設置應該是什麼。要匹配的最重要的設置是批處理大小,TPU拓撲和模型並行分區。這些表包括您應該期望查看的分數,如果您在t5x.eval中使用這些提示
這是有關及時調整的其他資源的集合。
如果您將此作品用作跳躍點,請引用
@inproceedings { lester-etal-2021-power ,
title = " The Power of Scale for Parameter-Efficient Prompt Tuning " ,
author = " Lester, Brian and
Al-Rfou, Rami and
Constant, Noah " ,
booktitle = " Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing " ,
month = nov,
year = " 2021 " ,
address = " Online and Punta Cana, Dominican Republic " ,
publisher = " Association for Computational Linguistics " ,
url = " https://aclanthology.org/2021.emnlp-main.243 " ,
doi = " 10.18653/v1/2021.emnlp-main.243 " ,
pages = " 3045--3059 " ,
}這不是官方支持的Google產品。