Model untuk menyelesaikan masalah anotasi jenis kolom dengan Bert, dilatih pada dataset RWT-Rutabert.
Dataset RWT-Rutabert berisi 1 441 349 kolom dari tabel wikipedia bahasa Rusia. Dengan header yang cocok dengan 170 tipe semantik dbpedia. Ini memiliki split kereta / tes tetap:
| Membelah | Kolom | Tabel | Rata -rata. kolom per tabel |
|---|---|---|---|
| Tes | 115 448 | 55 080 | 2.096 |
| Kereta | 1 325 901 | 633 426 | 2.093 |
Kami melatih Rutabert dengan dua strategi serialisasi meja:
Hasil Benchmark pada Dataset RWT-Rutabert:
| Strategi serialisasi | Mikro-F1 | makro-f1 | bobot-f1 |
|---|---|---|---|
| Multi-kolom | 0.962 | 0.891 | 0.9621 |
| Kolom tetangga | 0.964 | 0,904 | 0.9639 |
Parameter pelatihan:
| Parameter | Nilai |
|---|---|
| Ukuran batch | 32 |
| zaman | 30 |
| Fungsi kerugian | Cross-entropy |
| GD Optimizer | ADAMW (LR = 5E-5, EPS = 1E-8) |
| GPU | 4 Nvidia A100 (80 GB) |
| Benih acak | 2024 |
| validasi split | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
Konfigurasi model dapat ditemukan di file config.json .
Parameter argumen configuratoin tercantum di bawah ini:
| argumen | keterangan |
|---|---|
| num_labels | Jumlah label yang digunakan untuk klasifikasi |
| num_gpu | Jumlah GPU yang akan digunakan |
| save_period_in_epochs | Angka yang mengkarakterisasi dengan periodisitas apa pos pemeriksaan disimpan (dalam zaman) |
| metrik | Metrik klasifikasi yang digunakan adalah |
| pretrained_model_name | Nama pintasan Bert dari huggingface |
| table_serialization_type | Metode serialisasi tabel menjadi urutan |
| Batch_Size | Ukuran batch |
| num_epochs | Jumlah zaman pelatihan |
| acak_seed | Benih acak |
| logs_dir | Direktori untuk logging |
| train_log_filename | Nama file untuk penebangan kereta api |
| test_log_filename | Nama file untuk penebangan tes |
| start_from_checkpoint | Bendera untuk memulai pelatihan dari pos pemeriksaan |
| checkpoint_dir | Direktori untuk menyimpan pos pemeriksaan model |
| checkpoint_name | Nama file dari pos pemeriksaan (status model) |
| inference_model_name | Nama file model untuk inferensi |
| inference_dir | Direktori untuk menyimpan tabel inferensi .csv |
| dataloader.valid_split | Jumlah subset validasi split |
| dataloader.num_workers | Jumlah pekerja dataloader |
| dataset.num_rows | Jumlah baris yang dapat dibaca dalam dataset, jika null membaca semua baris dalam file |
| dataset.data_dir | Direktori untuk menyimpan file kereta/tes/inferensi |
| dataset.train_path | Direktori untuk Menyimpan File Dataset Kereta .csv |
| dataset.test_path | Direcotry untuk menyimpan file dataset uji .csv |
Kami merekomendasikan untuk hanya mengubah parameter:
num_gpu - angka ingeter positif + {0}. 0 Berdiri untuk Pelatihan / Pengujian di CPU.save_period_in_epochs - nomor integer positif apa pun, ukuran dalam zaman.table_serialization_type - "column_wise" atau "table_wise".pretrained_model_name - nama Bert Shorcut dari model pretrained pytorch huggingface.batch_size - nomor bilangan bulat positif.num_epochs - nomor integer positif apa pun.random_seed - nomor bilangan bulat apa pun.start_from_checkpoint - "true" atau "false".checkpoint_name - Nama model apa pun, disimpan di direktori checkpoint .inference_model_name - nama model apa pun, disimpan di direktori checkpoint . Tapi kami sarankan untuk menggunakan model terbaik: [model_best_f1_weighted.pt, model_best_f1_macro.pt, model_best_f1_micro.pt].dataloader.valid_split - bilangan real dalam kisaran [0,0, 1.0] (0,0 berdiri untuk 0 % dari subset kereta api, 0,5 berarti 50 % dari subset kereta). Atau nomor integer positif (menunjukkan jumlah tetap dari subset validasi).dataset.num_rows - "NULL" adalah singkatan dari Read All Lines in Dataset Files. Integer positif berarti jumlah baris untuk dibaca di file dataset. Sebelum melatih / menguji model yang Anda butuhkan:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh dari Dataset Repository, untuk memindahkan file dataset ke direktori data rutabert: RuTaBERT-Dataset$ ./move_dataset.shconfig.json sebelum pelatihan. Rutabert mendukung pelatihan / pengujian secara lokal dan dalam wadah Docker. Juga mendukung Slurm Workload Manager.
RuTaBERT$ virtualenv venvatau
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint .logs/ direktori ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).Persyaratan:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint .logs/ direktori ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).RuTaBERT$ virtualenv venvatau
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint .logs/ direktori ( train.log , test.log , error_train.log , error_test.log ). data/test Direktori.RuTaBERT$ ./download.sh table_wiseatau
RuTaBERT$ ./download.sh column_wiseconfig.json .RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ direktori ( test.log , error_test.log ). data/inference .RuTaBERT$ ./download.sh table_wiseatau
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv