هذا الريبو هو تنفيذ نموذج RNN اليقظة لمهمة نمذجة اللغة.
تتم نمذجة اللغة على كل من مجموعات بيانات PennTreebank و Wikitext-02. يتم تحليل الملفات بحيث يتكون كل مثال تدريبي من جملة واحدة من المجموعة ، مبطنة بطول دفعة أقصى 35. يتم قص جمل أطول. يتم ذلك من أجل إدارة الاهتمام وحضور الكلمات فقط في الجملة (قبل Timestep t إذا في timestep t).
تم اقتراح A-RNN-LM (الشبكة العصبية المتكررة القائمة على الانتباه لنمذجة اللغة) في الأصل في حوار متماسك مع نماذج اللغة القائمة على الانتباه (Hongyuan Mei et al. 2016 ، Link) ، وفي نماذج اللغة اليقظة (Salton et al. 2017 ، Link).
يتكون النموذج من تشغيل آلية انتباه تقليدية على الحالات المخفية السابقة لطبقات (طبقة) المشفرة لترميز متجه السياق الذي يتم دمجه بعد ذلك مع الحالة المخفية الأخيرة من أجل التنبؤ بالكلمة التالية في التسلسل.
التبعيات:
python=3.7torch>=1.0.0nltkmatplotlibtensorboardX تثبيت جميع الاستخفاف وتشغيل python main.py
سيتم تنزيل مجموعات البيانات وتجهيزها مسبقًا تلقائيًا.
خيارات متعددة للتشغيل هي محتملة تشغيل python main.py --help للحصول على القائمة الكاملة.
usage: main.py [-h] [--batch-size N] [--epochs N] [--lr LR] [--patience P]
[--seed S] [--log-interval N] [--dataset [{wiki-02,ptb}]]
[--embedding-size N] [--n-layers N] [--hidden-size N]
[--positioning-embedding N] [--input-dropout D]
[--rnn-dropout D] [--decoder-dropout D] [--clip N]
[--optim [{sgd,adam,asgd}]] [--salton-lr-schedule]
[--early-stopping-patience P] [--attention]
[--no-positional-attention] [--tie-weights]
[--file-name FILE_NAME] [--parallel]
PyTorch Attentive RNN Language Modeling
optional arguments:
-h, --help show this help message and exit
--batch-size N input batch size for training (default: 64)
--epochs N number of epochs to train (default: 40)
--lr LR learning rate (default: 30.0)
--patience P patience for lr decrease (default: 5)
--seed S random seed (default: 123)
--log-interval N how many batches to wait before logging training
status (default 10)
--dataset [{wiki-02,ptb}]
Select which dataset (default: ptb)
--embedding-size N embedding size for embedding layer (default: 20)
--n-layers N layer size for RNN encoder (default: 1)
--hidden-size N hidden size for RNN encoder (default: 20)
--positioning-embedding N
hidden size for positioning generator (default: 20)
--input-dropout D input dropout (default: 0.5)
--rnn-dropout D rnn dropout (default: 0.0)
--decoder-dropout D decoder dropout (default: 0.5)
--clip N value at which to clip the norm of gradients (default:
0.25)
--optim [{sgd,adam,asgd}]
Select which optimizer (default: sgd)
--salton-lr-schedule Enables same training schedule as Salton et al. 2017
(default: False)
--early-stopping-patience P
early stopping patience (default: 25)
--attention Enable standard attention (default: False)
--no-positional-attention
Disable positional attention (default: False)
--tie-weights Tie embedding and decoder weights (default: False)
--file-name FILE_NAME
Specific filename to save under (default: uses params
to generate)
--parallel Enable using GPUs in parallel (default: False)
| نموذج | عدد المعلمات | التحقق من الصحة | اختبار الحيرة |
|---|---|---|---|
| خط الأساس LSTM (Merity et al. ، 2017) | 7.86m | 66.77 | 64.96 |
| LM اليقظة (Salton et al. 2017) | 7.06m | 79.09 | 76.56 |
| موضعي اليوناني | 6.9m | 72.69 | 70.92 |
| نموذج | عدد المعلمات | التحقق من الصحة | اختبار الحيرة |
|---|---|---|---|
| خط الأساس LSTM (Merity et al. ، 2017) | 7.86m | 72.43 | 68.50 |
| LM اليقظة (Salton et al. 2017) | 7.06m | 78.43 | 74.37 |
| موضعي اليوناني | 6.9m | 74.39 | 70.73 |
يمكنك إعادة تشغيل جميع النماذج التي ولدت الجداول أعلاه عن طريق التشغيل ببساطة:
python test.py
ومع ذلك ، يرجى ملاحظة أن بعض هذه النماذج تتناول ما يزيد عن 8 ساعات لتتلاقى على وحدة معالجة الرسومات 1080 واحدة ، وبالتالي يمكن أن يكون إجمالي وقت التشغيل في التجربة حوالي يومين.
يتم تعطيل دعم متعدد GPU بشكل افتراضي حيث تبين أن له تأثير سلبي على النتائج. علاوة على ذلك ، نظرًا لأن الدُفعات صغيرة في الممارسة العملية ، فهي ليست أسرع بكثير حيث يتم قضاء الكثير من الوقت في إرسال الترجمة إلى وحدات معالجة الرسومات المعنية.
يظهر هنا مقارنات جنبًا إلى جنب لتوزيع الاهتمام على مثال:
الكلمات الموجودة في المحور السيني هي المدخلات في كل خطوة زمنية والكلمات الموجودة في المحور ص هي الأهداف. تم تدريب كلا النموذجين على مجموعة بيانات Wikitext-02 حتى التقارب.