RuTaBERT
IVMEM2024
RWT-Rutabert DatasetでトレーニングされたBertによる列タイプの注釈の問題を解決するためのモデル。
RWT-Rutabertデータセットには、ロシア語のウィキペディアテーブルからの1 441 349列が含まれています。ヘッダーは170 dbpediaセマンティックタイプに一致します。固定された電車 /テストの分割があります:
| スプリット | 列 | テーブル | 平均。テーブルごとの列 |
|---|---|---|---|
| テスト | 115 448 | 55 080 | 2.096 |
| 電車 | 1 325 901 | 633 426 | 2.093 |
2つのテーブルシリアル化戦略でルタバートを訓練しました。
RWT-Rutabertデータセットのベンチマーク結果:
| シリアル化戦略 | Micro-F1 | マクロ-F1 | 加重-F1 |
|---|---|---|---|
| マルチコラム | 0.962 | 0.891 | 0.9621 |
| 隣接する列 | 0.964 | 0.904 | 0.9639 |
トレーニングパラメーター:
| パラメーター | 価値 |
|---|---|
| バッチサイズ | 32 |
| エポック | 30 |
| 損失関数 | クロスエントロピー |
| GDオプティマイザー | Adamw(LR = 5E-5、EPS = 1E-8) |
| GPU | 4 Nvidia A100(80 GB) |
| ランダムシード | 2024 |
| 検証分割 | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
モデル構成は、file config.jsonに記載されています。
configuratoin引数パラメーターを以下に示します。
| 口論 | 説明 |
|---|---|
| num_labels | 分類に使用されるラベルの数 |
| num_gpu | 使用するGPUの数 |
| save_period_in_epochs | チェックポイントが保存される周期性で特徴付けられる数(エポックで) |
| メトリック | 使用される分類メトリックは次のとおりです |
| pretrained_model_name | Huggingfaceのバートショートカット名 |
| table_serialization_type | テーブルをシーケンスにシリアル化する方法 |
| batch_size | バッチサイズ |
| num_epochs | トレーニングエポックの数 |
| random_seed | ランダムシード |
| logs_dir | ロギング用のディレクトリ |
| train_log_filename | 電車のロギングのファイル名 |
| test_log_filename | テストロギングのファイル名 |
| start_from_checkpoint | チェックポイントからトレーニングを開始するフラグ |
| checkpoint_dir | モデルのチェックポイントを保存するためのディレクトリ |
| checkpoint_name | チェックポイントのファイル名(モデル状態) |
| Inference_model_name | 推論のためのモデルのファイル名 |
| Inference_dir | 推論表を保存するためのディレクトリ.csv |
| dataloader.valid_split | 検証サブセット分割の量 |
| dataloader.num_workers | Dataloader労働者の数 |
| dataset.num_rows | データセット内の読み取り可能な行の数、 nullファイル内のすべての行を読み取る場合 |
| dataset.data_dir | 列車/テスト/推論ファイルを保存するためのディレクトリ |
| dataset.train_path | 列車データセットファイルを保存するためのディレクトリ.csv |
| dataset.test_path | テストデータセットファイルを保存するためのdirecotry .csv |
これらのパラメーターのみを変更することをお勧めします。
num_gpu任意の正のイングテーター数 + {0}。 0 CPUでのトレーニング /テストの略。save_period_in_epochsポジティブな整数数、エポックの測定。table_serialization_type "column_wise"または "table_wise"。pretrained_model_name -huggingface pytorch事前処理されたモデルからのbert shorcut名。batch_size正の整数番号。num_epochs正の整数数。random_seed整数番号。start_from_checkpoint 「true」または「false」。checkpoint_name checkpointディレクトリに保存されたモデルの名前。inference_model_name checkpointディレクトリに保存されたモデルの名前。ただし、[Model_best_f1_weighted.pt、model_best_f1_macro.pt、model_best_f1_micro.pt]を使用することをお勧めします。dataloader.valid_split範囲内の実数[0.0、1.0](0.0は列車のサブセットの0%を表し、0.5は列車のサブセットの50%に耐えます)。または正の整数番号(検証サブセットの固定数を示します)。dataset.num_rows 「null」は、データセットファイルのすべての行を読み取るための略です。正の整数とは、データセットのファイルで読み取る行の数を意味します。 モデルをトレーニング /テストする前に、次のことが必要です。
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.shを実行して、データセットファイルをrutabert dataディレクトリに移動します。 RuTaBERT-Dataset$ ./move_dataset.shconfig.jsonファイルを構成します。 Rutabertは、ローカルおよび内部のDockerコンテナのトレーニング /テストをサポートしています。また、SluRMワークロードマネージャーもサポートしています。
RuTaBERT$ virtualenv venvまたは
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpointディレクトリに保存されます。logs/ directory( training_results.csv 、 train.log 、 test.log 、 error_train.log 、 error_test.log )にあります。要件:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpointディレクトリに保存されます。logs/ directory( training_results.csv 、 train.log 、 test.log 、 error_train.log 、 error_test.log )にあります。RuTaBERT$ virtualenv venvまたは
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpointディレクトリに保存されます。logs/ directory( train.log 、 test.log 、 error_train.log 、 error_test.log )にあります。 data/testディレクトリに配置されたデータを確認してください。RuTaBERT$ ./download.sh table_wiseまたは
RuTaBERT$ ./download.sh column_wiseconfig.jsonでテストするモデルを構成します。RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ directory( test.log 、 error_test.log )になります。 data/inferenceディレクトリに配置されたデータを確認してください。RuTaBERT$ ./download.sh table_wiseまたは
RuTaBERT$ ./download.sh column_wiseconfig.jsonで推論するモデルを構成しますRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csvになります