Modell zur Lösung des Problems des Spaltentypanschlags mit Bert, geschult auf dem RWT-Rutabert-Datensatz.
Der RWT-Rutabert-Datensatz enthält 1 441 349 Spalten aus russischen Sprachwikipedia-Tabellen. Mit Header, die 170 dbpedia semantische Typen entsprechen. Es hat einen festen Zug / Test -Split festgelegt:
| Teilt | Spalten | Tische | Avg. Spalten pro Tabelle |
|---|---|---|---|
| Prüfen | 115 448 | 55 080 | 2.096 |
| Zug | 1 325 901 | 633 426 | 2.093 |
Wir haben Rutabert mit zwei Tischserialisierungsstrategien ausgebildet:
Benchmark-Ergebnisse zum RWT-Rutabert-Datensatz:
| Serialisierungsstrategie | Micro-F1 | MACRO-F1 | gewichtet-f1 |
|---|---|---|---|
| Multi-Säulen | 0,962 | 0,891 | 0,9621 |
| Benachbarte Säule | 0,964 | 0,904 | 0,9639 |
Trainingsparameter:
| Parameter | Wert |
|---|---|
| Chargengröße | 32 |
| Epochen | 30 |
| Verlustfunktion | Cross-Entropie |
| GD -Optimierer | ADAMW (LR = 5E-5, EPS = 1E-8) |
| GPUs | 4 Nvidia A100 (80 GB) |
| zufälliger Samen | 2024 |
| Validierungsaufteilung | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
Die Modellkonfiguration finden Sie in der Datei config.json .
Die Konfiguratoin -Argumentparameter sind unten aufgeführt:
| Argument | Beschreibung |
|---|---|
| num_labels | Anzahl der für die Klassifizierung verwendeten Etiketten |
| num_gpu | Anzahl der zu verwendenden GPUs |
| save_period_in_epochs | Zahlen charakterisieren mit der Periodizität des Kontrollpunkts (in Epochen) (in Epochen) |
| Metriken | Die verwendeten Klassifizierungsmetriken sind |
| vorbereitet_model_name | BERT -Verknüpfungsname von Suggingface |
| table_serialization_type | Methode zur Serialisierung einer Tabelle in eine Sequenz |
| batch_size | Chargengröße |
| num_epochs | Anzahl der Trainingspochen |
| random_seed | Zufälliger Samen |
| logs_dir | Verzeichnis für die Protokollierung |
| train_log_filename | Dateiname für die Zugprotokollierung |
| test_log_filename | Dateiname für die Testprotokollierung |
| start_from_checkpoint | Flag, um das Training vom Checkpoint aus zu starten |
| Checkpoint_dir | Verzeichnis für die Speicherung von Modellkontrollen des Modells |
| Checkpoint_Name | Dateiname eines Checkpoint (Modellstatus) |
| inference_model_name | Dateiname eines Modells für Inferenz |
| Inferenz_Dir | Verzeichnis zum Speichern von Inferenztabellen .csv |
| Dataloader.valid_split | Betrag der Validierungs -Teilmenge geteilt |
| Dataloader.num_worker | Anzahl der DataLoader -Mitarbeiter |
| dataSet.num_rows | Anzahl der lesbaren Zeilen im Datensatz, wenn null alle Zeilen in Dateien lesen |
| dataSet.data_dir | Verzeichnis zum Speichern von Zug-/Test-/Inferenzdateien |
| DataSet.Train_Path | Verzeichnis zum Speichern von Zugdatensatzdateien .csv |
| DataSet.test_path | Direcotry zum Speichern von Testdatensatzdateien .csv |
Wir empfehlen, nur die Parameter zu ändern:
num_gpu - jede positive Ingeter -Nummer + {0}. 0 Stand für das Training / Test auf der CPU.save_period_in_epochs - Jede positive Ganzzahlnummer, Maßnahmen in Epochen.table_serialization_type - "column_wise" oder "table_wise".pretrained_model_name - Bert Shorcut -Namen von Huggingface Pytorch Pretrainierte Modelle.batch_size - jede positive Ganzzahlnummer.num_epochs - jede positive Ganzzahlnummer.random_seed - jede ganzzahlige Nummer.start_from_checkpoint - "true" oder "false".checkpoint_name - Ein beliebiger Name des Modells, gespeichert im checkpoint -Verzeichnis.inference_model_name - Ein beliebiger Name des Modells, gespeichert im checkpoint -Verzeichnis. Wir empfehlen jedoch, die besten Modelle zu verwenden: [model_best_f1_weighted.pt, model_best_f1_macro.pt, model_best_f1_micro.pt].dataloader.valid_split - reelle Zahl innerhalb von Bereich [0,0, 1,0] (0,0 steht für 0 % der Zug -Teilmenge, 0,5 für 50 % der Zug -Teilmenge). Oder eine positive Ganzzahlnummer (die eine feste Anzahl einer Validierungs -Teilmenge bezeichnet).dataset.num_rows - "null" steht für das Lesen aller Zeilen in Dataset -Dateien. Positive Ganzzahl bedeutet die Anzahl der Zeilen, die in den Dateien des Datensatzes gelesen werden sollen. Vor dem Training / Testen des Modells, das Sie benötigen:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh Sie das Skript aus dem Dataset -Repository aus, um Datensatzdateien in das Rutabert data zu verschieben: RuTaBERT-Dataset$ ./move_dataset.shconfig.json vor dem Training. Rutabert unterstützt Schulungen / Tests lokal und im Docker -Container. Unterstützt auch Slurm Workload Manager.
RuTaBERT$ virtualenv venvoder
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint -Verzeichnis gespeichert.logs/ Verzeichnis ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).Anforderungen:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint -Verzeichnis gespeichert.logs/ Verzeichnis ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).RuTaBERT$ virtualenv venvoder
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint -Verzeichnis gespeichert.logs/ Verzeichnis ( train.log , test.log , error_train.log , error_test.log ). data/test platziert werden.RuTaBERT$ ./download.sh table_wiseoder
RuTaBERT$ ./download.sh column_wiseconfig.json testen soll.RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ Verzeichnis ( test.log , error_test.log ). data/inference platziert werden.RuTaBERT$ ./download.sh table_wiseoder
RuTaBERT$ ./download.sh column_wiseconfig.json zu inferenzierenRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv sein. CSV