This repository contains the official code used for the language modeling experiments in the papers cited at the bottom of this README.
More generally, this can be used as a language modeling toolkit in PyTorch to experiment with:
Standard Transformers
Transformer-XL
Fast Weight Programmers with different update rules and linear attention functions; some combinations of these recover well-known models such as the linear Transformer (a minimal sketch of the basic update rules follows this list).
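To make the terminology above concrete, here is a minimal, self-contained sketch of a single fast weight step for one head, contrasting the sum update rule (plain, unnormalized linear attention) with the delta update rule. The function names, the ELU+1 feature map, and the toy dimensions are illustrative choices for this sketch, not the toolkit's actual implementation (see utils/fast_weight.py and the papers cited below for that).

```python
import torch
import torch.nn.functional as F

def phi(x):
    # Illustrative feature map (ELU + 1); the toolkit offers several
    # linear attention functions, this is just one simple example.
    return F.elu(x) + 1.0

def fwp_step(W, q, k, v, beta=None):
    """One fast weight step for a single head (illustrative sketch only).

    W: (d_v, d_k) fast weight matrix carried across time steps
    q, k: (d_k,) query and key; v: (d_v,) value; beta: optional scalar gate
    """
    k_feat, q_feat = phi(k), phi(q)
    if beta is None:
        # Sum update rule, W_t = W_{t-1} + v_t (outer) phi(k_t):
        # the update behind (unnormalized) linear attention.
        W = W + torch.outer(v, k_feat)
    else:
        # Delta update rule: retrieve the value currently stored under k,
        # then write only the (sigmoid-gated) difference.
        v_old = W @ k_feat
        W = W + torch.sigmoid(beta) * torch.outer(v - v_old, k_feat)
    y = W @ q_feat  # query the fast weight memory
    return y, W

# Toy usage: process a short random sequence with the delta rule.
d_k, d_v, T = 8, 8, 5
W = torch.zeros(d_v, d_k)
for _ in range(T):
    q, k, v = torch.randn(d_k), torch.randn(d_k), torch.randn(d_v)
    y, W = fwp_step(W, q, k, v, beta=torch.tensor(0.5))
```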
This repository contains two implementations of fast weights: a custom CUDA kernel and one based on torch.autograd.Function (see utils/fast_weight.py). While we only used the CUDA implementation for all our final experiments (it is faster and achieves much better GPU utilization), the torch.autograd.Function version can be useful for quickly prototyping new extensions.
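As an illustration of that prototyping route, the following is a hypothetical, minimal torch.autograd.Function with a hand-written backward for one sum-update write/read step. The class name and the exact forward/backward are made up for this sketch and are not the toolkit's code.

```python
import torch

class SumUpdateRetrieve(torch.autograd.Function):
    """Hypothetical example of a custom autograd op for prototyping
    (illustrative only; not the toolkit's actual class)."""

    @staticmethod
    def forward(ctx, W, k_feat, v, q_feat):
        W_new = W + torch.outer(v, k_feat)   # write with the sum update rule
        y = W_new @ q_feat                   # retrieve with the query
        ctx.save_for_backward(W_new, k_feat, v, q_feat)
        return y, W_new

    @staticmethod
    def backward(ctx, grad_y, grad_W_new):
        W_new, k_feat, v, q_feat = ctx.saved_tensors
        if grad_W_new is None:               # output W_new unused downstream
            grad_W_new = torch.zeros_like(W_new)
        # Gradient w.r.t. W_new collects contributions from both outputs.
        grad_W_total = grad_W_new + torch.outer(grad_y, q_feat)
        grad_W = grad_W_total                # dW_new/dW is the identity
        grad_k = grad_W_total.t() @ v        # from the outer product
        grad_v = grad_W_total @ k_feat       # from the outer product
        grad_q = W_new.t() @ grad_y          # from y = W_new @ q_feat
        return grad_W, grad_k, grad_v, grad_q

# Usage: one write/read step with gradients flowing through the custom backward.
d = 8
W = torch.zeros(d, d, requires_grad=True)
k_feat, v, q_feat = torch.randn(d), torch.randn(d), torch.randn(d)
y, W_new = SumUpdateRetrieve.apply(W, k_feat, v, q_feat)
y.sum().backward()
```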
This toolkit requires PyTorch (torch) and Ninja (ninja, needed to compile the CUDA kernels).
The experiments for the paper were conducted with Python 3.6 and PyTorch 1.4.0 (note on Aug 24, 2023: the code also works with Python 3.11 and PyTorch 2.0.1+cu117).
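A quick, optional way to check that the environment can build and run the CUDA kernels (this snippet is just a convenience check, not part of the toolkit):

```python
import torch
from torch.utils.cpp_extension import is_ninja_available

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Ninja available:", is_ninja_available())
```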
More recent versions of PyTorch are not yet well supported by this toolkit, which still uses torch.nn.DataParallel for multi-GPU training. If you really need a more recent version of PyTorch, check the PyTorch documentation and use torch.nn.parallel.DistributedDataParallel instead. We hope to fix this soon, but we cannot say exactly when.
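For orientation, swapping the wrapper roughly looks as follows; this is a minimal single-process sketch with a placeholder model, not a drop-in patch for this toolkit's training script:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Minimal single-process setup; real multi-GPU runs launch one process per
# GPU (e.g. with torchrun) and pass the proper rank/world_size.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(16, 16)   # placeholder for the language model
model = DDP(model)                # instead of torch.nn.DataParallel(model)

# ... usual training loop on `model` ...

dist.destroy_process_group()
```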
The toolkit supports Weights & Biases for monitoring jobs. If you use it, also install wandb.
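For reference, Weights & Biases logging typically looks like the following; the project and metric names are made up for this sketch and are not the toolkit's exact logging calls:

```python
import wandb

# Offline mode avoids needing an account/login for a quick local test.
run = wandb.init(project="lm-toolkit-demo", mode="offline")
for step in range(3):
    wandb.log({"train/loss": 1.0 / (step + 1)}, step=step)
run.finish()
```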
This repository contains many lines of code taken and adapted from the following sources:
Please check the files under example_scripts for general instructions and examples of how to train and evaluate models.
@inproceedings{schlag2021linear,
title={Linear Transformers Are Secretly Fast Weight Programmers},
author={Imanol Schlag and Kazuki Irie and J{\"u}rgen Schmidhuber},
booktitle={Proc. Int. Conf. on Machine Learning (ICML)},
address={Virtual only},
month = jul,
year={2021}
}
@article{irie2021going,
title={Going Beyond Linear Transformers with Recurrent Fast Weight Programmers},
author={Kazuki Irie and Imanol Schlag and R{\'o}bert Csord{\'a}s and J{\"u}rgen Schmidhuber},
journal={Preprint arXiv:2106.06295},
year={2021}
}