Dieses Repository enthält den Code und die Daten für das folgende Papier:
Mixce: Training autoregressive Sprachmodelle durch Mischen der Vorwärts- und Rückwärtskreuzentropien
@inproceedings{zhang2023mixce,
title={MixCE: Training Autoregressive Language Models by Mixing Forward and Reverse Cross-Entropies},
author={Zhang, Shiyue and Wu, Shijie and İrsoy, Ozan and Lu, Steven and Bansal, Mohit and Dredze, Mark and Rosenberg, David},
booktitle={Proceedings of the 61th Annual Meeting of the Association for Computational Linguistics},
year={2023}
}
Code Autor: Shiyue Zhang
python -m pip install -r requirements.txtOptional: Um Versionen zu vermeiden, können Sie die Installation unter einer virtuellen Umgebung durchführen:
python -m venv yourenv
. yourenv/bin/activate # for bash, might be something else for your particular shell
python -m pip install -r requirements.txtsynthetic.py ist das Skript zum Ausführen von synthetischen Experimenten. Ausführen von Experimenten ist sehr einfach, einfach rennen:
python synthetic.py
Konfigurationen (wie Saatgut, Vokabellengröße usw.) können innerhalb des Skripts angegeben und geändert werden und unter if __name__ == '__main__': .
Es gibt einige wichtige Konfigurationen innerhalb von Synthetic.Py, die bestimmen, welche Art von synthetischen Experimenten Sie durchführen können:
real_dataset : Wenn es None ist, wird die Übergangsmatrix zufällig initialisiert. Oder wenn es sich um 'webtext' handelt, wird die Übergangsmatrix aus den vorbereiteten Übergangsmatrizen im WebText initialisiert.
Zero_percent : Bestimmt, wie viele Werte in der Übergangsmatrix 0 sind. Wenn beispielsweise zero_percent==0.5 , sind 50% der Wahrscheinlichkeiten in der Übergangsmatrix 0.
VOCAB_SIZE : Die Vokabulargröße. Wir testen 21, 51, 101, 501 oder 1001. Beachten Sie, dass 21 bedeutet, dass wir 20 normale Token (einschließlich EOS) und 1 PAD -Token haben.
Saatgut : Wir laufen für jedes Experiment 5 Samen (7, 42, 777, 4222, 99999).
LUST_FUNC : Wir testen 4 Verlustfunktionen: (1) 'two_xens' : Es wird als Mixce* in unserem Papier bezeichnet und verwendet die Golddatenverteilung P und Gumbel Softmax; (2) 'qlogq_mix' : Es ist unsere approximierte Mixce -Verlustfunktion; (3) 'two_kls' : Die Mischung von zwei KL -Abweichungen; (4) 'js' : JS -Divergenz.
Train_eta : Das Mischverhältnis für diese Verlustfunktionen. Wenn train_eta==1.0 für 'two_xens' , ist es MLE. Wenn train_eta==1.0 für 'two_kls' , ist es vorwärts KL (auch gleich MLE). Wenn train_eta==0.0 für 'two_kls' , ist es das umgekehrte KL. Wir verwenden eine allgemeine Definition von JS (siehe dieses Papier für weitere Details) und JS konvergiert zu 0, wenn train_eta näher an 0,0 oder 1,0 kommt. Wenn train_eta = 0,5, ist es die normale Definition der JS -Divergenz.
Wir bewerten synthetisch ausgebildete Bigram LMS, indem wir die gelernte Übergangsmatrix mit der Goldübergangsmatrix vergleichen. Wir verwenden zwei Metriken:
(1) AVG. JS : Wir berechnen die JS -Divergenz zwischen jeder Reihe von Gold und gelernten Übergangsmatrizen und Durchschnitt über Zeilen.
(2) AVG. 0
Die Funktion compare_parameters() in synthetic.py wird für die Berechnung dieser beiden Metriken verwendet.
Modelle werden alle unter dem Verzeichnis synthetic_logs/ Verzeichnis gespeichert. Der Name jedes Modellverzeichnisses beginnt mit der DateTime, in der das Experiment ausgeführt wurde. Unter dem Modellverzeichnis finden Sie auch die Tensorboard -Ereignisdateien sowie einen all_best_metrics.json , der die besten Metriken -Scores für jedes Mischverhältnis spart. Siehe Beispiele unter synthetic_logs/.
Die Modellbewertung wird nach jeder Epoche durchgeführt, und der beste Kontrollpunkt wird basierend auf dem Verlust am Dev -Set ausgewählt.
Schließlich haben wir für jedes Experiment die Ergebnisse von 5 Samen durchschnittlich; Und für jedes Ziel wählen wir das beste Mischverhältnis basierend auf AVG. JS.
get_synthetic_results() in Ergebnissen.Py ist eine Funktion, die durchschnittlich Ergebnisse von 5 Samen und die Ergebnisse verschiedener Mischungsverhältnisse zu AVG verwendet wird. JS.
Um get_synthetic_results() zu verwenden, müssen Sie zuerst synthetic_models.json vorbereiten, um die Modellverzeichnisse anzugeben. Ein Beispiel wird in synthetic_models.json gezeigt. Anschließend können Sie das Ergebnis des Experiments erhalten, das WebText Initialisierte Übergangsmatrix, vocab = 20 und objective = two_kls verwendet, indem Sie get_synthetic_results('webtext', '20', 'two_kls') ausführen.
Detokenizer. Sie müssen zuerst detokenizer.perl von Moses hier herunterladen und unter die data/detokenizer.perl platzieren, da die folgenden Python -Skripte davon abhängen.
Dann:
cd data
python wikitext_data.py
python webtext_data.py
curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf -
python writingprompts_data.py
Die vorverarbeiteten Daten werden unter data/wikitext , data/webtext und data/writingPrompts gespeichert.
Klon-GPT-2-Modelle mit git lfs die der Anweisung befolgen, die durch Umarmung vorhanden ist.
git lfs install
git clone https://huggingface.co/gpt2
GPT2 ist das kleinste GPT-2-Modell. Wir experimentieren auch mit GPT2-Medium und GPT2-Large. GPT2-Large wird in der Berechnung von Mauve verwendet. Bitte laden Sie sie auch herunter:
git clone https://huggingface.co/gpt2-medium
git clone https://huggingface.co/gpt2-large
Machen Sie eine Kopie von GPT2-Large for Mauve:
cp -r gpt2-large gpt2-large-mauve
Weil wir direkt an GPT2-Large schreiben, was sich auf die Mauve-Berechnung auswirkt.
Sie können einfach Experimente durchführen, indem Sie:
python run.py
Konfigurationen können in run.py manuell angegeben werden. Siehe ein Beispiel unter if __name__ == '__main__' .
Es gibt einige wichtige Konfigurationen in Run.py :
Training_Size : Die Schulungsdatengröße, wir testen '10K' , '25K' , '50K' und '100K' ; Standardmäßig verwenden wir '50K' .
Modell : Es kann 'gpt2' , 'gpt2-meidum' oder 'gpt2-large' sein.
Datensatz : Es kann "wikitext" , "webtext" oder "writingPrompts" sein.
MIXING_RATIO : Wir suchen [0.0, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 1.0] und wählen Sie die besten mixing_ratio basierend auf Dev Set Mauve Score.
Train_batch_size, Akkumulation, Eval_batch_size : Diese Konfigurationen sollten durch die von Ihnen verwendete Plattform bestimmt werden. Wir verwenden einen einzelnen Tesla V100 -GPU (32G -Speicher), und die empfohlenen Konfigurationen in dieser Einstellung sind in run.py
Es gibt einen Diktat und drei Funktionen in Run.py :
Data_Sets {} : Es speichert die Pfade von Datendateien.
run_no_trainer () : Die Funktion zum Training und Bewertung von Modellen.
run_no_trainer_eval () : Die Funktion, die nur für die Modellbewertung verwendet wird.
run_no_trainer_turn_topp () : Die Funktion, die zum Abtauchen von Top-P-Sampling verwendet wird.
Neben Run.py stelle ich hier die anderen wichtigen Python -Skripte für Modelltraining und -bewertung vor:
GPT2.PY (die wichtigste Datei) enthält eine GPT2MixModel -Modellklasse, die unsere Mixce -Verlustfunktion implementiert .
run_clm_no_trainer.py ist das Skript zum Trainieren und Evaluieren von GPT-2-Modellen.
run_clm_no_trainer_tune_topp.py ähnelt run_clm_no_trainer.py , außer dass es nur zum Abtauchen des Hyperparameter P der Top-P-Probenahme verwendet wird.
metircs.py enthält die Metriken, die wir zur Bewertung von Modellgenerationen verwenden.
Modelle werden im train/ Verzeichnis gespeichert.
Der Name jedes Modellverzeichnisses beginnt mit der DateTime, in der das Experiment ausgeführt wurde. Nach dem Modellverzeichnis speichern wir den besten Checkpoint (ausgewählt basierend auf dem Dev -Verlust).
dev/test.sample , dev/test.sample1 , dev/test.sample2 und dev/test.human sind 3 unvoreingenommene Stichprobengenerationen und menschlichen Text.
dev/test_results.json retten die Ergebnisse von Verwirrung, Vielfalt und Wiederholung.
Nach dem Tuning von P für die Top-P-Probenahme sind dev/test.topp(p=*) Top-P-Sampling-Generationen mit unterschiedlichen p-Werten.
Nach der Berechnung von Mauve und Kohärenz (Einzelheiten siehe nächster Abschnitt), haben dev/test_mauve_coherence_*.json
Nach dem Berechnen von kontrolliertem Mauve und Kohärenz (Einzelheiten siehe nächster Abschnitt), werden dev/test_controlled_mauve_coherence_*.json kontrolliert Mauve- und Kohärenzwerte mit unterschiedlichen maximalen Längen.
Wir berichten über die Punktzahlen von 6 Metriken in unserem Artikel:
Verwirrung wird zusammen mit Modelltraining/-bewertung berechnet (siehe run_clm_no_trainer.py ).
Die Vielfalt wird von der diversity() -Funktion in metircs.py implementiert und zusammen mit dem Modelltraining/-bewertung berechnet, indem die Funktion compute_diversity_repetition() in run_clm_no_trainer.py aufgerufen wird. Beachten Sie, dass Wiederholung eine weitere Metrik ist, die wir implementiert haben, aber in unserem Artikel nicht gemeldet wurden. Es prüft, wie viel Prozent des Textes Wiederholungsschleifen sind, und gibt auch die sich wiederholende Phrasenlänge zurück.
Mauve und Kohärenz werden post-hoc-Weise unter Verwendung gespeicherter Generationsdateien berechnet. compute_mauve() und compute_coherence() in metrics.py sind zwei Helferfunktionen, um Mauve und Kohärenz zu berechnen. Sie werden von der Funktion compute_mauve_coherence() in results.py aufgerufen. Um compute_mauve_coherence() zu verwenden, müssen Sie zunächst die Models.json vorbereiten, um die Namensverzeichnisnamen für die Bewertung anzugeben.
In ähnlicher Weise können kontrollierte Mauve und kontrollierte Kohärenz post-hoc auch durch compute_controlled_mauve_coherence() in results.py berechnet werden.
| Datensatz | Modellgröße | Schulungsdatengröße | Objektiv | Umarme Face Hub -Name |
|---|---|---|---|---|
| Wikitext | GPT2-Large | 50k | MLE | shiyue/wikitext_train50k_gpt2-large_mix1.0 |
| Wikitext | GPT2-Large | 50k | Mixce (ETA = 0,1) | shiyue/wikitext_train50k_gpt2-large_mix0.1 |
| WebText | GPT2-Large | 50k | MLE | shiyue/webtext_train50k_gpt2-large_mix1.0 |
| WebText | GPT2-Large | 50k | Mixce (ETA = 0,3) | shiyue/webtext_train50k_gpt2-large_mix0.3 |
| WritingPrompts | GPT2-Large | 50k | MLE | shiyue/writingPrompts_train50k_gpt2-large_mix1.0 |
| WritingPrompts | GPT2-Large | 50k | Mixce (ETA = 0,7) | shiyue/writingPrompts_train50k_gpt2-large_mix0.7 |
Versuchen Sie es auf folgende Weise aus, vorbereitete Modelle:
>>> from gpt2 import GPT2MIXModel
>>> from transformers import GPT2Tokenizer
>>> model = GPT2MIXModel.from_pretrained("shiyue/wikitext_train50K_gpt2-large_mix1.0")
>>> tokenizer = GPT2Tokenizer.from_pretrained('shiyue/wikitext_train50K_gpt2-large_mix1.0')
>>> text = "Hey, how are you?"
>>> encoded_input = tokenizer(text, return_tensors='pt')
>>> model.eval()
>>> out_ids = model.lm.generate(inputs=encoded_input["input_ids"], max_length=50, do_sample=True)
>>> print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))
Wir sind Beiträge.
Haben Sie eine gute Erfahrung mit diesem Projekt gemacht? Warum nicht etwas Liebe teilen und Code beitragen oder uns nur über Probleme informieren, die Sie damit hatten?
Wir begrüßen Ausgabenberichte hier; Wählen Sie unbedingt die richtige Problemvorlage für Ihr Problem, damit wir sicher sein können, dass Sie uns die erforderlichen Informationen zur Verfügung stellen.
Bevor Sie eine Pull -Anfrage senden, lesen Sie bitte unsere Beitragsrichtlinien.
Die folgenden zwei Dateien werden aus dem transformers -Repository ausgeliehen und übernommen und behalten daher ihre ursprünglichen Urheberrechte bei.
Dies wird ursprünglich von https://github.com/huggingface/transformers/blob/main/examples/pytorch/glanguage-modeling/run_clm_no_trainer.py abgeholt. Darüber hinaus haben wir die folgenden Änderungen angewendet:
--test_file--reduction--mixing_ratio--max_length--prompt_length--eval_prompt_length--cache_dir--do_train--do_evalpush_to_hub ".DataCollatorWithPadding anstelle des Standardkollators.do_eval " hinzu, von denen die meisten in die neue Funktion ' evaluate() ' eingehen. Diese Datei wird weiter von run_clm_no_trainer.py (siehe oben) geändert, indem die Funktion generate() aufgerufen wird, um die Option top_p zu aktivieren.
Dieses Projekt hat einen Verhaltenskodex angenommen. Wenn Sie Bedenken hinsichtlich des Code oder des Verhaltens haben, das Sie im Projekt erlebt haben, kontaktieren Sie uns bitte unter [email protected].