RuTaBERT
IVMEM2024
نموذج لحل مشكلة توضيح نوع العمود مع BERT ، تدرب على مجموعة بيانات RWT-RUTABERT.
تحتوي مجموعة بيانات RWT-Rutabert على 1 441 349 من أعمدة اللغة الروسية. مع رؤوس مطابقة 170 dbpedia الأنواع الدلالية. يحتوي على تقسيم قطار / اختبار ثابت:
| ينقسم | الأعمدة | الطاولات | متوسط. أعمدة لكل جدول |
|---|---|---|---|
| امتحان | 115 448 | 55 080 | 2.096 |
| يدرب | 1 325 901 | 633 426 | 2.093 |
قمنا بتدريب Rutabert مع استراتيجيتين التسلسلية الجدول:
النتائج القياسية على مجموعة بيانات RWT-Rutabert:
| استراتيجية التسلسل | micro-F1 | ماكرو-F1 | مرجح |
|---|---|---|---|
| متعدد العمود | 0.962 | 0.891 | 0.9621 |
| العمود المجاور | 0.964 | 0.904 | 0.9639 |
معلمات التدريب:
| المعلمة | قيمة |
|---|---|
| حجم الدُفعة | 32 |
| الحقبة | 30 |
| وظيفة الخسارة | المتقاطع |
| GD Optimizer | Adamw (LR = 5E-5 ، EPS = 1E-8) |
| GPU's | 4 NVIDIA A100 (80 غيغابايت) |
| بذرة عشوائية | 2024 |
| انقسام التحقق | 5 ٪ |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
يمكن العثور على تكوين النموذج في ملف config.json .
وردت معلمات وسيطة configuratoin أدناه:
| دعوى | وصف |
|---|---|
| num_labels | عدد الملصقات المستخدمة للتصنيف |
| num_gpu | عدد وحدات معالجة الرسومات التي يجب استخدامها |
| save_period_in_epochs | العدد الذي يميز الدورية التي يتم حفظ نقطة التفتيش (في الحقبة) |
| المقاييس | مقاييس التصنيف المستخدمة |
| pretRained_Model_Name | اسم اختصار Bert من Huggingface |
| table_serialization_type | طريقة تسلسل الجدول في تسلسل |
| batch_size | حجم الدُفعة |
| num_epochs | عدد عصر التدريب |
| عشوائي | بذرة عشوائية |
| logs_dir | دليل للتسجيل |
| Train_log_filename | اسم الملف لتسجيل القطار |
| test_log_filename | اسم الملف لتسجيل الاختبار |
| start_from_checkpoint | العلم لبدء التدريب من نقطة التفتيش |
| checkpoint_dir | دليل لتخزين نقاط التفتيش من النموذج |
| checkpoint_name | اسم ملف نقطة التفتيش (حالة النموذج) |
| Interference_model_name | اسم ملف نموذج للاستدلال |
| الاستدلال | دليل لتخزين جداول الاستدلال .csv |
| dataloader.valid_split | مبلغ تقسيم مجموعة التحقق من الصحة |
| dataloader.num_workers | عدد عمال Dataloader |
| dataset.num_rows | عدد الصفوف القابلة للقراءة في مجموعة البيانات ، إذا قرأت null جميع الصفوف في الملفات |
| Dataset.data_dir | دليل لتخزين ملفات القطار/الاختبار/الاستدلال |
| Dataset.train_path | دليل لتخزين ملفات مجموعة بيانات القطار .csv |
| dataset.test_path | direcotry لتخزين ملفات مجموعة بيانات الاختبار .csv |
نوصي بتغيير المعلمات Theese فقط:
num_gpu - أي رقم إيجابي + {0}. 0 وقف للتدريب / الاختبار على وحدة المعالجة المركزية.save_period_in_epochs - أي رقم عدد صحيح موجب ، يقيس في الحقبة.table_serialization_type - "column_wise" أو "table_wise".pretrained_model_name - أسماء Bert Shorcut من نماذج Pytorch Pytorch.batch_size - أي رقم عدد صحيح موجب.num_epochs - أي رقم عدد صحيح موجب.random_seed - أي رقم عدد صحيح.start_from_checkpoint - "true" أو "false".checkpoint_name - أي اسم للنموذج ، تم حفظه في دليل checkpoint .inference_model_name - أي اسم للنموذج ، يتم حفظه في دليل checkpoint . لكننا نوصي باستخدام أفضل النماذج: [model_best_f1_weighted.pt ، model_best_f1_macro.pt ، model_best_f1_micro.pt].dataloader.valid_split - عدد حقيقي ضمن النطاق [0.0 ، 1.0] (0.0 يقف 0 ٪ من مجموعة فرعية القطار ، 0.5 يقف 50 ٪ من مجموعة القطار الفرعية). أو رقم عدد صحيح موجب (يشير إلى عدد ثابت من مجموعة فرعية التحقق من الصحة).dataset.num_rows - "NULL" تعني قراءة جميع الخطوط في ملفات البيانات. الصدفة الإيجابية تعني عدد الخطوط التي يجب قراءتها في ملفات مجموعة البيانات. قبل التدريب / اختبار النموذج الذي تحتاجه:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh من مستودع مجموعة البيانات ، لنقل ملفات مجموعة البيانات إلى دليل data RUTABERT: RuTaBERT-Dataset$ ./move_dataset.shconfig.json قبل التدريب. يدعم Rutabert التدريب / الاختبار محليًا وداخل حاوية Docker. كما يدعم Slurm Workload Manager.
RuTaBERT$ virtualenv venvأو
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint .logs/ الدليل ( training_results.csv ، train.log ، test.log ، error_train.log ، error_test.log ).متطلبات:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint .logs/ الدليل ( training_results.csv ، train.log ، test.log ، error_train.log ، error_test.log ).RuTaBERT$ virtualenv venvأو
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint .logs/ دليل ( train.log ، test.log ، error_train.log ، error_test.log ). data/test .RuTaBERT$ ./download.sh table_wiseأو
RuTaBERT$ ./download.sh column_wiseconfig.json .RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ الدليل ( test.log ، error_test.log ). data/inference .RuTaBERT$ ./download.sh table_wiseأو
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv