RuTaBERT
IVMEM2024
แบบจำลองสำหรับการแก้ปัญหาของคำอธิบายประกอบประเภทคอลัมน์กับ Bert ได้รับการฝึกฝนในชุดข้อมูล RWT-Rutabert
ชุดข้อมูล RWT-Rutabert ประกอบด้วย 1 441 349 คอลัมน์จากตารางภาษารัสเซีย Wikipedia ด้วยส่วนหัวที่ตรงกับประเภทความหมาย 170 dbpedia มันมีการแยกรถไฟ / ทดสอบแยก:
| แยก | คอลัมน์ | โต๊ะ | AVG คอลัมน์ต่อตาราง |
|---|---|---|---|
| ทดสอบ | 115 448 | 55 080 | 2.096 |
| รถไฟ | 1 325 901 | 633 426 | 2.093 |
เราฝึก Rutabert ด้วยกลยุทธ์การทำให้เป็นอนุกรมสองตาราง:
ผลการวัดผลในชุดข้อมูล RWT-Rutabert:
| กลยุทธ์การทำให้เป็นอนุกรม | Micro-F1 | มาโคร F1 | ถ่วงน้ำหนัก -F1 |
|---|---|---|---|
| คอลัมน์หลายคอลัมน์ | 0.962 | 0.891 | 0.9621 |
| คอลัมน์ใกล้เคียง | 0.964 | 0.904 | 0.9639 |
พารามิเตอร์การฝึกอบรม:
| พารามิเตอร์ | ค่า |
|---|---|
| ขนาดแบทช์ | 32 |
| ยุค | 30 |
| ฟังก์ชันการสูญเสีย | ข้ามจุดกำเนิด |
| GD Optimizer | ADAMW (LR = 5E-5, EPS = 1E-8) |
| GPU | 4 Nvidia A100 (80 GB) |
| เมล็ดสุ่ม | 2024 |
| แยกการตรวจสอบความถูกต้อง | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
การกำหนดค่าโมเดลสามารถพบได้ในไฟล์ config.json
พารามิเตอร์อาร์กิวเมนต์ configuratoin แสดงอยู่ด้านล่าง:
| การโต้แย้ง | คำอธิบาย |
|---|---|
| num_labels | จำนวนฉลากที่ใช้สำหรับการจำแนกประเภท |
| num_gpu | จำนวน GPU ที่จะใช้ |
| save_period_in_epochs | ตัวเลขที่แสดงถึงการบันทึกจุดตรวจสอบเป็นระยะเวลาใด (ในยุค) |
| ตัวชี้วัด | ตัวชี้วัดการจำแนกประเภทที่ใช้คือ |
| pretrained_model_name | ชื่อทางลัดเบิร์ตจาก HuggingFace |
| table_serialization_type | วิธีการทำให้ตารางเป็นลำดับในลำดับ |
| batch_size | ขนาดแบทช์ |
| num_epochs | จำนวนยุคฝึกอบรม |
| Random_seed | เมล็ดสุ่ม |
| logs_dir | ไดเรกทอรีสำหรับการบันทึก |
| train_log_filename | ชื่อไฟล์สำหรับการบันทึกรถไฟ |
| test_log_filename | ชื่อไฟล์สำหรับการบันทึกการทดสอบ |
| start_from_checkpoint | ตั้งค่าสถานะเพื่อเริ่มการฝึกอบรมจากจุดตรวจ |
| จุดตรวจสอบ _dir | ไดเรกทอรีสำหรับการจัดเก็บจุดตรวจของรุ่น |
| จุดตรวจสอบ _name | ชื่อไฟล์ของจุดตรวจ (สถานะรุ่น) |
| inference_model_name | ชื่อไฟล์ของโมเดลสำหรับการอนุมาน |
| การอนุมาน _dir | ไดเรกทอรีสำหรับการจัดเก็บตารางการอนุมาน .csv |
| dataloader.valid_split | จำนวนชุดย่อยการตรวจสอบความถูกต้อง |
| dataloader.num_workers | จำนวนคนงาน dataloader |
| DataSet.num_rows | จำนวนแถวที่อ่านได้ในชุดข้อมูลถ้า null อ่านแถวทั้งหมดในไฟล์ |
| dataSet.data_dir | ไดเรกทอรีสำหรับการจัดเก็บไฟล์รถไฟ/ทดสอบ/การอนุมาน |
| DataSet.train_path | ไดเรกทอรีสำหรับการจัดเก็บไฟล์ชุดข้อมูลรถไฟ .csv |
| DataSet.test_path | DIRECOTRY สำหรับการจัดเก็บไฟล์ชุดข้อมูลทดสอบ .csv |
เราขอแนะนำให้เปลี่ยนเฉพาะพารามิเตอร์ Theese:
num_gpu - หมายเลข ingeter บวกใด ๆ + {0} 0 ยืนสำหรับการฝึกอบรม / ทดสอบใน CPUsave_period_in_epochs - หมายเลขจำนวนเต็มบวกใด ๆ มาตรการในยุคtable_serialization_type - "column_wise" หรือ "table_wise"pretrained_model_name - ชื่อ Bert Shorcut จาก Huggingface Pytorch รุ่น Pretrainedbatch_size - หมายเลขจำนวนเต็มบวกใด ๆnum_epochs - หมายเลขจำนวนเต็มบวกใด ๆrandom_seed - หมายเลขจำนวนเต็มใด ๆstart_from_checkpoint - "true" หรือ "false"checkpoint_name - ชื่อของรุ่นใด ๆ ที่บันทึกไว้ในไดเรกทอรี checkpointinference_model_name - ชื่อของโมเดลใด ๆ ที่บันทึกไว้ในไดเรกทอรี checkpoint แต่เราขอแนะนำให้ใช้โมเดลที่ดีที่สุด: [model_best_f1_weighted.pt, model_best_f1_macro.pt, model_best_f1_micro.pt]dataloader.valid_split - จำนวนจริงภายในช่วง [0.0, 1.0] (0.0 หมายถึง 0 % ของชุดย่อยรถไฟ, 0.5 หมายถึง 50 % ของชุดย่อยรถไฟ) หรือหมายเลขจำนวนเต็มบวก (แสดงจำนวนชุดย่อยการตรวจสอบความถูกต้องคงที่)dataset.num_rows - "Null" หมายถึงการอ่านทั้งหมดในไฟล์ชุดข้อมูล จำนวนเต็มบวกหมายถึงจำนวนบรรทัดที่จะอ่านในไฟล์ของชุดข้อมูล ก่อนการฝึกอบรม / ทดสอบรูปแบบที่คุณต้อง:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh จากที่เก็บข้อมูลชุดข้อมูลเพื่อย้ายไฟล์ชุดข้อมูลไปยังไดเรกทอรี data Rutabert: RuTaBERT-Dataset$ ./move_dataset.shconfig.json ก่อนการฝึกอบรม Rutabert รองรับการฝึกอบรม / การทดสอบทั้งในและภายในคอนเทนเนอร์ Docker ยังรองรับ Slurm Workload Manager
RuTaBERT$ virtualenv venvหรือ
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpointlogs/ ไดเรกทอรี ( training_results.csv , train.log , test.log , error_train.log , error_test.log )ความต้องการ:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpointlogs/ ไดเรกทอรี ( training_results.csv , train.log , test.log , error_train.log , error_test.log )RuTaBERT$ virtualenv venvหรือ
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpointlogs/ ไดเรกทอรี ( train.log , test.log , error_train.log , error_test.log ) data/testRuTaBERT$ ./download.sh table_wiseหรือ
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ ไดเรกทอรี ( test.log , error_test.log ) data/inferenceRuTaBERT$ ./download.sh table_wiseหรือ
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv