This is the repository for the code used in the paper "Jump to Conclusions: Short-Cutting Transformers With Linear Transformations".
Please cite the paper as:
@article{din2023jump,
title={Jump to Conclusions: Short-Cutting Transformers With Linear Transformations},
author={Yom Din, Alexander and Karidi, Taelin and Choshen, Leshem and Geva, Mor},
journal={arXiv preprint arXiv:2303.09435},
year={2023},
}

To produce plots for gpt2 and Wikipedia sentences, run the following scripts, in the written order:
get_wikipedia_sentences.py
(produces ./experiment/sentences/wikipedia_20K-sentences.pickle, containing 20K sentences from Wikipedia)
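For orientation, here is a minimal sketch of how 20K Wikipedia sentences could be collected with the `datasets` and `spacy` packages listed at the bottom; the dataset config, spaCy model, and filtering are assumptions, not necessarily what get_wikipedia_sentences.py does.

```python
# Hedged sketch (not the exact logic of get_wikipedia_sentences.py): stream
# Wikipedia articles and split them into sentences with spaCy.
import pickle
import datasets
import spacy

nlp = spacy.load("en_core_web_sm")  # assumes this spaCy model is installed
wiki = datasets.load_dataset("wikipedia", "20220301.en", split="train", streaming=True)

sentences = []
for article in wiki:
    # Truncate very long articles so spaCy stays within its default max length.
    for sent in nlp(article["text"][:10_000]).sents:
        sentences.append(sent.text.strip())
    if len(sentences) >= 20_000:
        break

with open("./experiment/sentences/wikipedia_20K-sentences.pickle", "wb") as f:
    pickle.dump(sentences[:20_000], f)
```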
add_tokenization.py
(produces ./experiment/gpt2/wikipedia_tokenized_train.pickle containing the tokenizations and random token positions for the first 9000 sentences from the file produced by the previous script, and ./experiment/gpt2/wikipedia_tokenized_val.pickle containing the tokenizations and random token positions for the next 3000 sentences)
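The tokenization step amounts to something like the following sketch; the train/validation split sizes come from the description above, while the way a random token position is sampled is an assumption.

```python
# Hedged sketch of the tokenization step: tokenize each sentence with the gpt2
# tokenizer and record one random token position per sentence.
import pickle
import random
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

with open("./experiment/sentences/wikipedia_20K-sentences.pickle", "rb") as f:
    sentences = pickle.load(f)

def tokenize(subset):
    records = []
    for sent in subset:
        input_ids = tokenizer(sent)["input_ids"]
        records.append({"input_ids": input_ids,
                        "position": random.randrange(len(input_ids))})
    return records

train, val = tokenize(sentences[:9000]), tokenize(sentences[9000:12000])
```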
add_linreg.py
(produces ./linreg/gpt2/wikipedia/i_j.pickle, where i and j are layer indices, containing the linear regression fitted to map hidden representations from layer i to layer j)
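The core operation here is an ordinary least-squares fit from layer-i representations to layer-j representations of the same token. A minimal sketch, assuming an illustrative layer pair, toy training sentences, and the last token position rather than the scripts' random positions:

```python
# Hedged sketch: fit a linear map A such that h_j ≈ A(h_i), where h_i and h_j
# are hidden states of the same token at layers i and j of gpt2.
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", output_hidden_states=True)
model.eval()

i, j = 2, 10  # source and target layers (illustrative choice)
X, Y = [], []
for sent in ["A short example sentence.", "Another training sentence."]:
    with torch.no_grad():
        out = model(**tokenizer(sent, return_tensors="pt"))
    hs = out.hidden_states  # tuple of (1, seq_len, hidden_dim) tensors, one per layer
    pos = -1                # here: the last token; the scripts use a random position
    X.append(hs[i][0, pos].numpy())
    Y.append(hs[j][0, pos].numpy())

reg = LinearRegression().fit(np.stack(X), np.stack(Y))  # linear shortcut from layer i to layer j
```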
add_plot_r2.py
(produces ./experiment/gpt2/wikipedia_r2_scores.pickle containing the r² scores of the fitted linear mappings, and ./experiments/gpt2/plots/wikipedia/r2_scores_12.pdf containing the heatmap plots for these scores)
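Plotting such a layer-by-layer score matrix is a standard heatmap; the sketch below uses random placeholder values in place of the real r² scores and an illustrative output path.

```python
# Hedged sketch: draw an r² heatmap over (source layer, target layer) pairs.
# The scores here are random placeholders; only the plotting is illustrated.
import matplotlib.pyplot as plt
import numpy as np

num_layers = 13  # gpt2: embedding output plus 12 transformer blocks
r2_scores = np.random.rand(num_layers, num_layers)  # placeholder values

plt.imshow(r2_scores, vmin=0.0, vmax=1.0, cmap="viridis")
plt.xlabel("target layer")
plt.ylabel("source layer")
plt.colorbar(label="r²")
plt.savefig("r2_scores_heatmap.pdf")
```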
add_linreg_submodules.py
(produces ./linreg/gpt2/wikipedia/pi_a_b.pickle, analogous to the output of add_linreg.py but for submodule-level mappings, where the file name encodes the submodule and the layer indices)
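Collecting submodule-level representations can be done with forward hooks; the following is a sketch under that assumption (which submodule inputs/outputs the script actually regresses on is not spelled out here).

```python
# Hedged sketch: capture the outputs of gpt2's attention and MLP sublayers
# with forward hooks; these per-submodule representations can then be used
# for linear-regression fits analogous to the layer-to-layer ones above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        # GPT2Attention returns a tuple; its first element is the output tensor.
        captured[name] = output[0] if isinstance(output, tuple) else output
    return hook

for layer_idx, block in enumerate(model.transformer.h):
    block.attn.register_forward_hook(make_hook(f"attn_{layer_idx}"))
    block.mlp.register_forward_hook(make_hook(f"mlp_{layer_idx}"))

with torch.no_grad():
    model(**tokenizer("A short example sentence.", return_tensors="pt"))
# `captured` now maps e.g. "attn_3" / "mlp_3" to tensors of shape (1, seq_len, hidden_dim).
```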
add_results.py
(produces ./experiment/gpt2/wikipedia_results.pickle containing (for each validation set sample) the top 10 tokens, as well as the model's surprisal of the top 1 token, according to the five mappings of the paper, at each layer; and also containing the top 10 tokens and the number of layers processed when early-exiting and using the mappings)
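Reading token predictions off an early layer amounts to (optionally) mapping the hidden state and pushing it through the model's final layer norm and unembedding. A sketch, using the identity mapping for simplicity; one of the learned mappings would be applied in its place.

```python
# Hedged sketch: take a hidden state from an early gpt2 layer and decode it
# with the final layer norm and LM head to get the top-10 tokens and the
# surprisal of the top-1 token.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", output_hidden_states=True)
model.eval()

layer, pos = 4, -1  # illustrative source layer and token position
with torch.no_grad():
    out = model(**tokenizer("The capital of France is", return_tensors="pt"))
    h = out.hidden_states[layer][0, pos]  # hidden state at the chosen layer
    # A learned linear mapping to the final layer would be applied to h here;
    # this sketch uses the identity mapping.
    logits = model.lm_head(model.transformer.ln_f(h))
    probs = torch.softmax(logits, dim=-1)

top10 = [tokenizer.decode(t) for t in probs.topk(10).indices]
surprisal = -torch.log(probs.max())  # negative log-probability of the top-1 token
```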
plot_results.py
(produces some plots in ./experiment/gpt2/plots/wikipedia/ based on the results in the previous file's output)
To produce plots for bert-base-uncased and Wikipedia sentences, run the following, in the written order:
get_wikipedia_sentences.py
(the same as for gpt2 above, no need to re-run)
bert_add_reps.py
(produces ./experiment/bert-base-uncased_mask/wikipedia_train.pickle containing the tokenizations, random token positions and representations of the masked random token at all layers for the first 9000 sentences from the file produced by the previous script, and ./experiment/bert-base-uncased_mask/wikipedia_val.pickle containing the tokenizations, random token positions and representations of the masked random token at all layers for the next 3000 sentences)
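Extracting the masked token's representation at every layer can be sketched as follows; the example sentence and the way the masked position is chosen are illustrative.

```python
# Hedged sketch: mask a random token of a sentence and collect the masked
# position's hidden state at every layer of bert-base-uncased.
import random
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
model.eval()

encoding = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
input_ids = encoding["input_ids"]
pos = random.randrange(1, input_ids.shape[1] - 1)  # avoid [CLS]/[SEP]
input_ids[0, pos] = tokenizer.mask_token_id

with torch.no_grad():
    out = model(**encoding)
# One vector per layer (embeddings + 12 blocks) for the masked position.
reps = [h[0, pos] for h in out.hidden_states]
```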
bert_add_linreg.py
(produces ./linreg/bert-base-uncased_mask/wikipedia/i_j.pickle, where i and j are layer indices, containing the linear regression fitted to map the masked token's representation from layer i to layer j)
bert_add_plot_r2.py
(produces ./experiment/bert-base-uncased_mask/wikipedia_r2_scores.pickle containing the r² scores of the fitted linear mappings, and ./experiments/bert-base-uncased_mask/plots/wikipedia/r2_scores_12.pdf containing the heatmap plots for these scores)
bert_add_results.py
(produces ./experiment/bert-base-uncased_mask/wikipedia_results.pickle containing (for each validation set sample) the top 10 tokens, as well as the model's surprisal of the top 1 token, according to the mappings of the paper, at each layer)
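Decoding a (possibly mapped) hidden representation of the masked token into vocabulary predictions goes through BERT's masked-LM head; a sketch, decoding a raw early-layer state where one of the learned mappings would normally be applied first:

```python
# Hedged sketch: push a masked-token hidden state through bert-base-uncased's
# MLM head to obtain top-10 token predictions and the top-1 surprisal.
import torch
from transformers import AutoTokenizer, BertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased", output_hidden_states=True)
model.eval()

encoding = tokenizer("The capital of France is [MASK].", return_tensors="pt")
mask_pos = (encoding["input_ids"][0] == tokenizer.mask_token_id).nonzero().item()

layer = 6  # illustrative early layer
with torch.no_grad():
    out = model(**encoding)
    h = out.hidden_states[layer][0, mask_pos]
    # A learned linear mapping to the final layer would be applied to h here.
    logits = model.cls(h)  # BERT's masked-LM prediction head
    probs = torch.softmax(logits, dim=-1)

top10 = [tokenizer.decode(t) for t in probs.topk(10).indices]
surprisal = -torch.log(probs.max())  # negative log-probability of the top-1 token
```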
plot_results.py (change model_folder_name='bert-base-uncased_mask' and plot_parts = False)
(produces some plots in ./experiment/bert-base-uncased_mask/plots/wikipedia/ based on the results in the previous file's output)
We also produced plots for gpt2-medium, gpt2-large, gpt2-xl, and bert-large-uncased. To do that, one should modify, in a relatively straightforward way, the variables at the head of each script in the sequence.
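For example, switching the pipeline to gpt2-medium would involve header variables of roughly the following kind; the names here are hypothetical and vary per script (plot_results.py, for instance, uses model_folder_name and plot_parts, as noted above).

```python
# Hypothetical example of the kind of header variables to edit per script.
model_name = "gpt2-medium"         # Hugging Face model id
model_folder_name = "gpt2-medium"  # used in the ./experiment/... and ./linreg/... paths
num_layers = 24                    # gpt2-medium has 24 transformer blocks
```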
The code was run with Python 3.10.4 and the following package versions:
torch.__version__ = 1.13.1+cu117
transformers.__version__ = 4.20.1
sklearn.__version__ = 1.2.0
pickle.format_version = 4.0
datasets.__version__ = 2.5.2 # used only to fetch Wikipedia sentences
spacy.__version__ = 3.5.0 # used only to fetch Wikipedia sentences
Some of the trained matrices can be found at https://huggingface.co/sashay/linear-shortcut.