RuTaBERT
IVMEM2024
用於用BERT解決列類型註釋問題的模型,在RWT-Rutabert數據集上訓練。
RWT-Rutabert數據集包含1 441 349列Wikipedia表。標題與170個DBPEDIA語義類型相匹配。它具有固定的火車 /測試拆分:
| 分裂 | 列 | 表 | avg。每張表列 |
|---|---|---|---|
| 測試 | 115 448 | 55 080 | 2.096 |
| 火車 | 1 325 901 | 633 426 | 2.093 |
我們通過兩種表序列化培訓了魯塔伯特:
RWT-Rutabert數據集上的基準結果:
| 序列化策略 | Micro-F1 | 宏F1 | 加權f1 |
|---|---|---|---|
| 多列 | 0.962 | 0.891 | 0.9621 |
| 相鄰列 | 0.964 | 0.904 | 0.9639 |
培訓參數:
| 範圍 | 價值 |
|---|---|
| 批量大小 | 32 |
| 時代 | 30 |
| 損失功能 | 跨腸道 |
| GD優化器 | adamw(LR = 5E-5,EPS = 1E-8) |
| GPU的 | 4 NVIDIA A100(80 GB) |
| 隨機種子 | 2024 |
| 驗證拆分 | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
模型配置可以在文件config.json中找到。
configuratoin參數參數如下:
| 爭論 | 描述 |
|---|---|
| num_labels | 用於分類的標籤數量 |
| num_gpu | 使用的GPU數量 |
| save_period_in_epochs | 數字表徵了保存檢查點的周期性(以時期) |
| 指標 | 使用的分類指標是 |
| 預處理_model_name | Bert快捷方式來自Huggingface |
| table_serialization_type | 將表序列化為序列的方法 |
| batch_size | 批量大小 |
| num_epochs | 訓練時期的數量 |
| Random_seed | 隨機種子 |
| logs_dir | 記錄目錄 |
| train_log_filename | 火車記錄的文件名 |
| test_log_filename | 測試記錄的文件名 |
| start_from_checkpoint | 標誌開始從檢查站開始訓練 |
| checkpoint_dir | 存儲模型檢查點的目錄 |
| checkpoint_name | 檢查點的文件名(模型狀態) |
| penperion_model_name | 推理模型的文件名 |
| 推理_dir | 存儲推理表的目錄.csv |
| dataloader.valid_split | 驗證子集拆分量 |
| dataloader.num_workers | 數據加載工人的數量 |
| dataset.num_rows | 數據集中可讀行的數量,如果null讀取文件中的所有行 |
| dataset.data_dir | 存儲火車/測試/推理文件的目錄 |
| dataset.train_path | 存儲火車數據集文件的目錄.csv |
| dataset.test_path | 用於存儲測試數據集文件.csv |
我們建議僅更改ESE參數:
num_gpu任何正iggeter編號 + {0}。 0代表CPU培訓 /測試。save_period_in_epochs任何正整數,時期的度量。table_serialization_type “ column_wise”或“ table_wise”。pretrained_model_name -huggingface pytorch預處理模型中的bert shorcut名稱。batch_size任何正整數。num_epochs任何正整數號。random_seed任何整數編號。start_from_checkpoint “ true”或“ false”。checkpoint_name模型的任何名稱,保存在checkpoint目錄中。inference_model_name任何模型的名稱,保存在checkpoint目錄中。但是我們建議使用最佳模型:[model_best_f1_weighted.pt,model_best_f1_macro.pt,model_best_f1_micro.pt]。dataloader.valid_split範圍內的實際數字[0.0,1.0](0.0代表火車子集的0%,0.5代表火車子集的50%。或正整數數(表示固定數量的驗證子集)。dataset.num_rows “ null”代表讀取數據集文件中的所有行。正整數是指在數據集的文件中讀取的行數。 在培訓 /測試模型之前,您需要:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh ,將數據集文件移動到rutabert data目錄: RuTaBERT-Dataset$ ./move_dataset.shconfig.json文件。 Rutabert支持當地和Docker內部的培訓 /測試。還支持Slurm Workload Manager。
RuTaBERT$ virtualenv venv或者
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint目錄中。logs/目錄中( training_results.csv , train.log , test.log , error_train.log , error_test.log )。要求:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint目錄中。logs/目錄中( training_results.csv , train.log , test.log , error_train.log , error_test.log )。RuTaBERT$ virtualenv venv或者
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint目錄中。logs/目錄中( train.log , test.log , error_train.log , error_test.log )。 data/test目錄中。RuTaBERT$ ./download.sh table_wise或者
RuTaBERT$ ./download.sh column_wiseconfig.json中測試的模型。RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/目錄中( test.log , error_test.log )。 data/inference目錄中。RuTaBERT$ ./download.sh table_wise或者
RuTaBERT$ ./download.sh column_wiseconfig.json中配置哪種模型對推斷RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv中