RuTaBERT
IVMEM2024
用于用BERT解决列类型注释问题的模型,在RWT-Rutabert数据集上训练。
RWT-Rutabert数据集包含1 441 349列Wikipedia表。标题与170个DBPEDIA语义类型相匹配。它具有固定的火车 /测试拆分:
| 分裂 | 列 | 表 | avg。每张表列 |
|---|---|---|---|
| 测试 | 115 448 | 55 080 | 2.096 |
| 火车 | 1 325 901 | 633 426 | 2.093 |
我们通过两种表序列化培训了鲁塔伯特:
RWT-Rutabert数据集上的基准结果:
| 序列化策略 | Micro-F1 | 宏F1 | 加权f1 |
|---|---|---|---|
| 多列 | 0.962 | 0.891 | 0.9621 |
| 相邻列 | 0.964 | 0.904 | 0.9639 |
培训参数:
| 范围 | 价值 |
|---|---|
| 批量大小 | 32 |
| 时代 | 30 |
| 损失功能 | 跨肠道 |
| GD优化器 | adamw(LR = 5E-5,EPS = 1E-8) |
| GPU的 | 4 NVIDIA A100(80 GB) |
| 随机种子 | 2024 |
| 验证拆分 | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
模型配置可以在文件config.json中找到。
configuratoin参数参数如下:
| 争论 | 描述 |
|---|---|
| num_labels | 用于分类的标签数量 |
| num_gpu | 使用的GPU数量 |
| save_period_in_epochs | 数字表征了保存检查点的周期性(以时期) |
| 指标 | 使用的分类指标是 |
| 预处理_model_name | Bert快捷方式来自Huggingface |
| table_serialization_type | 将表序列化为序列的方法 |
| batch_size | 批量大小 |
| num_epochs | 训练时期的数量 |
| Random_seed | 随机种子 |
| logs_dir | 记录目录 |
| train_log_filename | 火车记录的文件名 |
| test_log_filename | 测试记录的文件名 |
| start_from_checkpoint | 标志开始从检查站开始训练 |
| checkpoint_dir | 存储模型检查点的目录 |
| checkpoint_name | 检查点的文件名(模型状态) |
| penperion_model_name | 推理模型的文件名 |
| 推理_dir | 存储推理表的目录.csv |
| dataloader.valid_split | 验证子集拆分量 |
| dataloader.num_workers | 数据加载工人的数量 |
| dataset.num_rows | 数据集中可读行的数量,如果null读取文件中的所有行 |
| dataset.data_dir | 存储火车/测试/推理文件的目录 |
| dataset.train_path | 存储火车数据集文件的目录.csv |
| dataset.test_path | 用于存储测试数据集文件.csv |
我们建议仅更改ESE参数:
num_gpu任何正iggeter编号 + {0}。 0代表CPU培训 /测试。save_period_in_epochs任何正整数,时期的度量。table_serialization_type “ column_wise”或“ table_wise”。pretrained_model_name -huggingface pytorch预处理模型中的bert shorcut名称。batch_size任何正整数。num_epochs任何正整数号。random_seed任何整数编号。start_from_checkpoint “ true”或“ false”。checkpoint_name模型的任何名称,保存在checkpoint目录中。inference_model_name任何模型的名称,保存在checkpoint目录中。但是我们建议使用最佳模型:[model_best_f1_weighted.pt,model_best_f1_macro.pt,model_best_f1_micro.pt]。dataloader.valid_split范围内的实际数字[0.0,1.0](0.0代表火车子集的0%,0.5代表火车子集的50%。或正整数数(表示固定数量的验证子集)。dataset.num_rows “ null”代表读取数据集文件中的所有行。正整数是指在数据集的文件中读取的行数。 在培训 /测试模型之前,您需要:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh ,将数据集文件移动到rutabert data目录: RuTaBERT-Dataset$ ./move_dataset.shconfig.json文件。 Rutabert支持当地和Docker内部的培训 /测试。还支持Slurm Workload Manager。
RuTaBERT$ virtualenv venv或者
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint目录中。logs/目录中( training_results.csv , train.log , test.log , error_train.log , error_test.log )。要求:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint目录中。logs/目录中( training_results.csv , train.log , test.log , error_train.log , error_test.log )。RuTaBERT$ virtualenv venv或者
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint目录中。logs/目录中( train.log , test.log , error_train.log , error_test.log )。 data/test目录中。RuTaBERT$ ./download.sh table_wise或者
RuTaBERT$ ./download.sh column_wiseconfig.json中测试的模型。RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/目录中( test.log , error_test.log )。 data/inference目录中。RuTaBERT$ ./download.sh table_wise或者
RuTaBERT$ ./download.sh column_wiseconfig.json中配置哪种模型对推断RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv中