Ini adalah kode untuk mereproduksi percobaan dari kertas EMNLP 2021 "Kekuatan skala untuk penyetelan cepat yang efisien parameter" (Lester et al., 2021).
Model -model ini dibangun di atas T5X, yang mendefinisikan model dan loop pelatihan; Flaxformer, yang mendefinisikan perhitungan model aktual; Rami, yang mendefinisikan lapisan model tingkat rendah; dan Jax, yang memberikan eksekusi yang sebenarnya. Rincian implementasi kami dapat ditemukan di sini.
gs://{bucket-name}/path/to/item/in/bucket . Di sinilah kami akan menyimpan kumpulan data yang di -cache serta pos pemeriksaan dan hasil model. Untuk kemudahan referensi, beberapa perintah cloud yang paling umum untuk berinteraksi dengan TPU VMS # Create a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME}
--zone ${ZONE}
--accelerator-type v3-8
--version v2-alpha
# SSH into a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE}
# Delete a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm delete ${TPU_NAME} --zone ${ZONE}git clone --branch=main https://github.com/google-research/prompt-tuning
cd prompt-tuningpython3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html Jika Anda mengalami kesalahan di mana PIP mencoba menginstal versi dependensi yang lebih awal dan earliler (TensorFlow misalnya) sampai mencoba menginstal versi 0.0.0 dan kemudian gagal, coba tambahkan --use-deprecated=legacy-resolver ke perintah instalasi. Kesalahan ini terkait dengan versi yang diperlukan antara dependensi dan perilaku sering disebut backtracking. Jika Anda menggunakan bendera, ada kemungkinan bahwa versi perpustakaan yang tidak kompatibel dapat diinstal dan Anda harus mencari peringatan tentang ketidakcocokan dalam output perintah instalasi.
Catatan: Jika Anda berencana untuk meretas internal penyetelan cepat dan memerlukan instalasi yang dapat diedit (jadi perubahan dalam kode kloning digunakan saat Anda menjalankan pelatihan) menjalankan pip dengan bendera -e dan Anda mungkin perlu menghapus file pyproject.toml jika Anda mendapatkan kesalahan selama instalasi.
Untuk menjalankan tes, instal paket dengan [test] (python3 -m python3 -m pytest python3 -m pip install .[test] ...
Melatih prompt mirip dengan menyempurnakan model dengan T5X; Perbedaan utama adalah bahwa kami memiliki set file konfigurasi tuning prompt kami sendiri untuk digunakan.
Kami menyediakan skrip demo ( prompt_tuning/scripts/sst2-demo.sh ) yang memiliki semua bagian yang diperlukan untuk melatih prompt. Anda dapat menggunakan ini sebagai titik awal, atau mengatur variabel lingkungan MODEL_DIR dan TFDS_DATA_DIR dengan jalur ke ember google cloud storage Anda untuk menjalankan skrip ini secara langsung.
./prompt-tuning/prompt_tuning/scripts/sst2-demo.shUntuk membantu dengan kecepatan iterasi, kami cenderung menentukan lebih banyak opsi baris perintah daripada menggabungkan semua konfigurasi ke dalam satu file gin. Beberapa opsi catatan:
--gin_search_paths :: Daftar direktori yang terpisah koma untuk digunakan sebagai awalan jalur untuk file gin. Kita dapat menggunakan prompt_tuning.scripts.find_module ${module} untuk menemukan lokasi instalasi perpustakaan yang menggabungkan konfigurasi dengan mereka.--gin_file :: File gin untuk dimuat. Kami cenderung menggunakan jalur relatif dimulai dengan pustaka yang diinstal dengan mereka, yaitu prompt_tuning/configs/models/t5_1_1_base_prompt.gin over models/t5_1_1_base_prompt.gin untuk menghindari kebingungan. Menggunakan flag beberapa waktu dapat digunakan untuk menentukan beberapa file gin yang akan digabungkan bersama. Opsi konfigurasi apa pun yang diatur dalam beberapa file akan menggunakan nilai dari file terakhir dalam daftar.--gin.{PARAM}={VALUE} :: Bendera override umum ini akan menetapkan PARAM ke VALUE . Ini dapat digunakan untuk dengan mudah mengatur opsi konfigurasi tanpa mengharuskannya menjadi argumen baris perintah yang sebenarnya. Misalnya. --gin.utils.SaveCheckpointConfig.keep=20 akan menyimpan 20 pos pemeriksaan terakhir.Ketika model semakin besar, XL dan XXL misalnya, mereka tidak sesuai dengan 8 TPU yang datang dengan TPU VM tunggal. Dalam kasus ini kita akan membutuhkan sepotong pod TPU (informasi lebih lanjut tentang arsitektur TPU dan konfigurasi yang tersedia dapat ditemukan di sini). Perbedaan utama antara melatih prompt pada TPU VM tunggal dan pada irisan pod adalah bahwa kami sekarang memiliki beberapa VM TPU dan akan menjalankan SPMD Jax yang sama masing-masing VM, halaman ini memiliki lebih banyak informasi tentang program JAX multi-host. Panduan ini memberikan pengantar cepat untuk menjalankan program JAX pada TPU Pod Slice, tetapi kami akan mencapai poin utama di sini.
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME}
--zone ${ZONE}
--accelerator-type v3-32
--version v2-alpha--command= flag dan harus dijalankan pada semua VM kami (disebut pekerja) dengan --worker=all . $ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " git clone --branch=main https://github.com/google-research/prompt-tuning && cd prompt-tuning && "
python3 -m pip install . -f https://storage.googleapis.com/jax-releases/libtpu_releases.html Tulis skrip untuk melatih prompt Anda. Kami menyertakan skrip demo ( /prompt_tuning/scripts/sst2-xxl-demo.sh ) melatih prompt untuk menyelesaikan dataset SST2 menggunakan T5 1.1 LM100K XXL. Anda dapat menggunakan ini sebagai titik awal atau hanya mengisi jalur ke ember google cloud storage Anda untuk menentukan di mana Anda ingin menyimpan hasil Anda ( MODEL_DIR ) dan di mana untuk men -cache data TFDS ( TFDS_DATA_DIR ), atau mengaturnya sebagai variabel lingkungan.
Salin skrip pelatihan Anda setiap pekerja. Jika ini adalah pertama kalinya Anda menjalankan scp , Anda mungkin mendapatkan kesalahan, jalankan perintah ssh-add /.../.ssh/google_compute_engine dari pesan kesalahan dan coba lagi.
$ gcloud alpha compute tpus tpu-vm scp sst2-xxl-demo.sh ${TPU_NAME} :
--zone= ${ZONE}
--worker=all$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " ./sst2-xxl-demo.sh " Jika salah satu pekerja memiliki kesalahan selama pelatihan, Anda akan dibiarkan dengan proses yang menggunakan TPU pada pekerja lain. Ini akan menghentikan Anda untuk memulai kembali pekerjaan Anda sampai proses tersebut diakhiri dan melepaskan TPU. Perintah berikut harus mengakhiri semua proses ini. Anda mungkin melihat halaman kill Command Man kembali dari pekerja yang memiliki kesalahan awal.
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " sudo lsof -t /dev/accel0 | xargs kill -9 "Untuk melatih petunjuk menggunakan bagian khusus, seperti dataset Anda sendiri, ikuti instruksi T5X pada komponen khusus
Jika Anda mengemas kode Anda sebagai paket Python yang diinstalasi pip, Anda tidak akan terikat pada satu direktori tunggal, dan Anda dapat menggunakan python3 -m prompt_tuning.scripts.find_module {your_module} untuk membantu mengatur gin_search_paths sehingga gin konfigurasi yang dibundel di perpustakaan Anda dapat ditemukan. Catatan: Jika Anda berencana untuk menggabungkan konfigurasi gin dalam paket yang dapat diinstal, pastikan bahwa direktori yang berisi file konfigurasi memiliki __init__.py karena gin memerlukan file untuk berada dalam paket python.
Jika bagian dari komponen khusus Anda dapat dikonfigurasi, mereka perlu diimpor secara eksplisit dalam file gin Anda; Jika akhirnya diimpor setelah file gin diuraikan, mereka akan menyebabkan kesalahan. Jika tidak ada dependensi Anda yang mengandung gin yang dapat dikonfigurasi, Anda dapat menghindari menulis file gin dengan melewati --gin.MIXTURE_OR_TASK_MODULE="'path.to.your.module' . Ini secara otomatis akan mengimpor modul Anda dan lebih mudah untuk ketika semua yang Anda lakukan bertukar data.
Cara kami yang disarankan untuk melakukan inferensi dengan prompt adalah dengan memuat pos pemeriksaan asli yang digunakan untuk menginisialisasi model, dan prompt dari file. Seperti yang dijelaskan dalam bagian ini tentang pemuatan parsial T5X mendukung memuat beberapa parameter model sambil menginisialisasi yang lain dari awal. Kami menggunakan ini bersamaan dengan inisialisasi prompt from_array untuk memuat ulang parameter beku dari pos pemeriksaan asli dan file prompt file. configs/runs/prompt_eval.gin mengatur konfigurasi ini untuk Anda; Anda hanya perlu menyediakan PROMPT_FILE . Jika model Anda dilatih dengan salah satu file prompts/ config, Anda dapat menghapusnya dari argumen ke skrip evaluasi.
Script sst2-demo-eval.sh yang disertakan menunjukkan contoh melakukan evaluasi dengan cara ini. Yang diperlukan hanyalah mengatur variabel lingkungan EVAL_DIR dan TFDS_DATA_DIR ke jalur untuk menyimpan output evaluasi dan cache dataset tensorflow dengan hormat.
Di T5X, skrip evaluasi mengasumsikan bahwa dataset Anda memiliki label dan mengeluarkan hasil akhir dari fungsi metrik dataset Anda. Skrip inferensi tidak memerlukan label dan sebaliknya menghasilkan prediksi model Anda. Kami menyertakan file prompt_infer.gin analog untuk digunakan dengan skrip inferensi.
Jika Anda ingin melakukan inferensi atau evaluasi dengan pos pemeriksaan T5X yang diproduksi dari menjalankan pelatihan tuning cepat, Anda dapat menggunakan (eval|infer).gin config dari T5X secara langsung. Anda perlu memperbarui utils.RestoreChekcpointConfig . Anda harus mengatur path ke pos pemeriksaan baru, assignment_map=() dan fallback_to_scratch=False .
Semua model, pelatihan, evaluasi, menyimpan, memulihkan, dll. Konfigurasi dilakukan melalui gin. Lihat Gudang Gin-Config untuk pengantar umum untuk gin dan primer ini
Kami mengikuti tata letak konfigurasi T5X:
runs/ :: berisi konfigurasi untuk pelatihan model yang sebenarnya. Di sinilah hal -hal seperti dataset dan konfigurasi evaluasi pergi.architectures/ :: Berisi konfigurasi untuk cara kerja model. Di sinilah hal-hal seperti Encoder-Decoder vs Decoder-Only dan Sharing Embedding dikonfigurasi.models/ :: Berisi konfigurasi yang mengatur model parameter spesifik seperti jumlah lapisan atau ukuran tabel embedding. Ini juga mengkonfigurasi hal -hal seperti pembungkus model T5X yang digunakan.models/decoding/ :: berisi konfigurasi yang mudah digunakan untuk menukar bagaimana model menghasilkan teks selama inferensi, termasuk konfigurasi untuk pencarian balok dan pengambilan sampel nukleus.models/sizes/ :: Berisi berbagai pengaturan untuk membuat model dengan ukuran yang berbeda, ini dikombinasikan dengan versi default untuk membuat versi berukuran, misalnya, t5_1_1_prompt.gin + sizes/large.gin membuat model besar T5 1.1. Beberapa kombinasi umum sudah tersedia sebagai file gin dengan hak termasuk ( t5_1_1_large_prompt.gin untuk contoh kami di atas). Catatan: File ukuran ini harus datang setelah file model utama.prompts/ :: Direktori tambahan kami berisi konfigurasi yang mengatur variabel gin PROMPT , memungkinkan untuk memudahkan pengalihan inisialisasi prompt berdasarkan file prompt mana yang ditambahkan sebagai argumen --gin_file (perlu dilakukan setelah file models/ gin), Saat menentukan --gin_file argumen di baris perintah, urutan penting. Urutan umum di mana file gin harus ditentukan adalah:
models/*.ginprompts/*.ginmodels/sizes/*.gin*models/decoding/*.ginruns/*.gin T5X memiliki beberapa bidang yang diperlukan seperti MIXTURE_OR_TASK_NAME atau TASK_FEATURE_LENGTHS . Kami menambahkan dua lagi:
PROMPT_LENGTH :: Panjang prompt yang kami gunakan, ini digunakan di beberapa tempat yang berbeda untuk kami membutuhkannya sebagai makro gin yang dapat kami referensi di banyak tempat dan memastikan nilainya sinkron.PROMPT :: Ini adalah konfigurasi modul prompt aktual yang akan digunakan dalam subkelas PromptX FlaxFormer. Catatan: Penyetelan cepat saat ini tidak mendukung pengemasan contoh. Ini berarti bahwa panjang target maks kami hanya perlu cukup lama untuk menyesuaikan target untuk setiap contoh. Ini berarti kunci targets kami dalam pemetaan TASK_FEATURE_LENGTHS bisa jauh lebih pendek, misalnya sekitar 4 untuk banyak tugas superglue (Wang et al., 2019), dibandingkan dengan 62 yang merupakan default p5x.
Ada beberapa opsi untuk inisialisasi parameter prompt. Kami mendukung berbagai metode di Bagian 3.2 Makalah kami, serta inisialisasi dari file. Yang terakhir memungkinkan seseorang untuk melakukan hal -hal seperti kereta di Boolq mulai dari prompt yang dipelajari di MNLI.
Semua inisialisasi mengikuti API inisialisasi Flax menjadi fungsi parameter yang mengembalikan penutupan atas fungsi inisialisasi. Fungsi inisialisasi aktual selalu memiliki tanda tangan
def initializer ( rng : Array , shape : Sequence [ int ]) -> Array :
... Kami menyediakan setiap skema inisialisasi sebagai file konfigurasi gin di direktori configs/prompts . Mereka dapat digunakan dengan memasukkan file gin dengan --gin_file=path/to/configs/prompts/scheme.gin . File ini perlu muncul setelah file model utama, jika tidak, metode default (seragam acak) akan menimpa yang Anda pilih. Beberapa metode inisialisasi ini akan mengharuskan Anda untuk menetapkan nilai gin tambahan baik melalui bendera override di salah satu file gin Anda.
Seragam acak
Inisialisasi standar dan acak yang mirip dengan apa yang telah digunakan orang untuk menanamkan inisialisasi. Ini adalah default dan tidak ada file gin yang diperlukan. Skala nilai acak dapat disesuaikan dengan overridding prompt_init/linen.initializers.uniform.scale=N .
Vocab sampel
Contoh embedding token untuk digunakan sebagai inisialisasi untuk setiap posisi prompt dengan from_sample_of_embeddings initializer. Anda dapat membatasi pengambilan sampel ke embeddings n pertama dengan prompt_init/prompts.from_samples_of_embeddings.population_size parameter.
Ini dapat digunakan dengan --gin_file=prompt_tuning/configs/prompts/from_sampled_vocab.gin . Metode ini menggunakan tabel embedding yang diekstraksi dari pos pemeriksaan model awal. Anda juga dapat menyediakan file embedding Anda sendiri dengan --gin_file=prompt_tuning/configs/prompts/from_sampled_vocab_numpy.gin . Metode ini mengharuskan Anda memberikan nilai untuk EMBEDDING_FILE yang merupakan array yang tidak bagus dari tabel embedding model. Ini dapat diekstraksi dari pos pemeriksaan model menggunakan prompt_tuning.scripts.extract_variable.
Label kelas
Kami mendukung inisialisasi waktu yang cepat dengan embedding label kelas (alias verbalizer ) melalui initializer from_embedded_list . Pengguna yang menyediakan daftar kata (label kelas) untuk digunakan. Setiap kata di -tokenized oleh vocab yang disediakan; Tertanam dengan meja vocab yang disediakan; agregat, jika perlu, melintasi sub-token; dan digunakan untuk menginisialisasi langkah waktu yang cepat. Jika token yang disediakan tidak menutupi panjang prompt penuh, token yang hilang diinisialisasi menggunakan inisialisasi Fall Back yang disediakan.
Kita dapat mencocokkan kertas, di mana token prompt yang tidak terisi diisi dengan pengambilan sampel dari tabel embedding, dengan menyusun inisialisasi ini dengan yang di atas. Ini dapat digunakan dengan --gin_file=prompt_tuning/configs/prompts/from_class_labels.gin . Ini membutuhkan pengaturan CLASS_LABELS , yang merupakan daftar kata -kata yang ingin Anda embed sebagai inisialisasi yang cepat. Anda juga dapat menyediakan file embedding Anda sendiri (yang sama seperti di atas) dengan --gin_file=prompt_tuning/configs/prompts/from_class_labels_numpy.gin . Ini juga membutuhkan pengaturan EMBEDDING_FILE .
Dari string
Kami juga mendukung menginisialisasi prompt dengan embedding beberapa string, sering digunakan untuk memulai dari prompt diskrit atau deskripsi tugas. Ini menggunakan inisialisasi from_embedded_string . String ini ditopkenized oleh kosakata yang disediakan, masing -masing token terlihat di tabel embedding yang disediakan, dan representasi tertanam yang dihasilkan dari string digunakan sebagai inisialisasi yang cepat. Jika token yang disediakan tidak menutupi panjang prompt penuh, token yang hilang diinisialisasi menggunakan inisialisasi Fall Back yang disediakan.
Catatan: Kosakata hanya mengubah string menjadi urutan ID, Anda perlu memastikan bahwa string cocok dengan hasil dari setiap pemformatan teks (spasi di sekitar tanda baca, dll.) Tugas Seqio Anda melakukannya.
Dari file
Anda juga dapat memuat prompt dari file dengan Initializer from_array untuk mengaktifkan transfer lintas tugas. Ini dilakukan dengan --gin_file=prompt_tuning/configs/prompts/from_file.gin . Ini membutuhkan pengaturan PROMPT_FILE dengan jalur ke file numpy dengan prompt untuk memuat. Versi numpy dari prompt dipancarkan secara default saat pelatihan, tetapi prompt juga dapat diekstraksi dengan skrip yang disebutkan di atas.
Kami telah merilis pos pemeriksaan asli T5X dari pos pemeriksaan T5 1.1 yang memiliki 100 ribu langkah adaptasi model bahasa.
Ini dikonversi dari pos pemeriksaan tensorflow public mesh.
Kami telah merilis prompt pretrained pada berbagai tugas, dan berencana untuk menambahkannya dari waktu ke waktu.
Prompt dapat ditemukan di direktori pretrained_prompts . Dari sana setiap kelompok sub-direktori diminta oleh model tempat mereka dilatih. Cara termudah untuk merujuk petunjuk ini yang dibundel dengan perpustakaan adalah:
--PROMPT_FILE= ` python3 -m prompt_tuning.scripts.find_module prompt_tuning ` /pretrained_prompts/{MODEL_SIZE}/{PROMPT}.npy Karena keacakan yang melekat dari perhitungan paralel, ada beberapa pengaturan yang perlu cocok antara pelatihan dan evaluasi untuk mendapatkan angka yang sama persis. Setiap model sub-direktori memiliki README.md Menentukan apa pengaturan ini seharusnya. Pengaturan terpenting yang cocok adalah ukuran batch, topologi TPU, dan partisi paralelisme model. Tabel termasuk skor yang harus Anda harapkan untuk melihat apakah Anda menggunakan prompt ini di t5x.eval
Ini adalah kumpulan sumber daya tambahan tentang penyetelan cepat.
Jika Anda menggunakan pekerjaan ini sebagai titik lompat, silakan mengutip
@inproceedings { lester-etal-2021-power ,
title = " The Power of Scale for Parameter-Efficient Prompt Tuning " ,
author = " Lester, Brian and
Al-Rfou, Rami and
Constant, Noah " ,
booktitle = " Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing " ,
month = nov,
year = " 2021 " ,
address = " Online and Punta Cana, Dominican Republic " ,
publisher = " Association for Computational Linguistics " ,
url = " https://aclanthology.org/2021.emnlp-main.243 " ,
doi = " 10.18653/v1/2021.emnlp-main.243 " ,
pages = " 3045--3059 " ,
}Ini bukan produk Google yang didukung secara resmi.