Реализация Pytorch нейронной вероятностной языковой модели. Код для обучения и загрузки данных на основе модели языка уровня слов Pytorch.
Чтобы получить набор данных wikitext-2, запустите:
./get-data.shПример уровня слов:
./main.py train --name wiki --order 5 --batch-size 32Пример уровня персонажа:
./main.py train --name wiki-char --use-char --order 12 --emb-dim 20 --batch-size 1024Если у вас предварительно подготовленные векторы перчаток, вы можете использовать их:
./main.py train --name wiki --use-glove --glove-dir your/glove/dir --emb-dim 50Некоторые другие аргументы данных:
--lower # Lowercase all words in training data.
--no-headers # Remove all headers such as `=== History ===`. С следующими аргументами одна эпоха занимает около 45 минут:
./main.py train --name wiki --order 5 --use-glove --emb-dim 50 --hidden-dims 100
--batch-size 128 --epochs 10 # Test perplexity 224.89 
Мы можем исследовать ограничения:
./main.py train --name wiki --order 13 --emb-dim 100 --hidden-dims 500
--epochs 40 --batch-size 512 --dropout 0.5 # Test perplexity 153.12 
./main.py train --name wiki --order 13 --emb-dim 300 --hidden-dims 1400
--epochs 40 --batch-size 256 --dropout 0.65 # Test perplexity 152.64 
Для создания текста используйте:
./main.py generate --checkpoint path/to/saved/model Токен <eos> заменяется новой линией, а остальное напечатано как есть.
Другое поколение аргументов:
--temperature 0.9 # Temperature to manipulate distribution.
--start # Provide an optional start of the generated text (can be longer than order)
--no-unk # Do not generate unks, especially useful for low --temperature.
--no-sos # Do not print <sos> tokensСмотрите немного сгенерированного текста в Generate.txt.
Чтобы визуализировать обученные встроены модели, используйте:
./main.py plot --checkpoint path/to/saved/modelЭто соответствует 2D-графику T-SNE с раскраской кластера K-средних из 1000 наиболее распространенных слов в наборе данных. Требуется боке для заговора и Scikit-Learn для T-SNE и K-средних.
Смотрите пример HTML здесь. (GitHub не визуализирует файлы HTML. Для рендеринга, загрузки и открытия или использования этой ссылки.)
python>=3.6
torch==0.3.0.post4
numpy
tqdm