Модель для решения проблемы аннотации типа столбца с помощью BERT, обученной набору данных RWT-Rutabert.
Набор данных RWT-Rutabert содержит 1 441 349 столбцов из российского языка таблиц Википедии. С заголовками, соответствующими 170 семантическим типам Dbpedia. Он имеет фиксированное разделение поезда / тестирования:
| Расколоть | Колонны | Столы | Ав. столбцы на таблицу |
|---|---|---|---|
| Тест | 115 448 | 55 080 | 2.096 |
| Тренироваться | 1 325 901 | 633 426 | 2.093 |
Мы обучили Рутаберт с двумя стратегиями сериализации таблицы:
Результаты эталона на наборе данных RWT-Rutabert:
| Стратегия сериализации | Микро-F1 | макро-F1 | взвешенный-F1 |
|---|---|---|---|
| Многоколон | 0,962 | 0,891 | 0,9621 |
| Соседняя колонка | 0,964 | 0,904 | 0,9639 |
Параметры обучения:
| Параметр | Ценить |
|---|---|
| Размер партии | 32 |
| эпохи | 30 |
| Функция потерь | Перекрестная энтропия |
| GD Optimizer | ADAMW (LR = 5E-5, EPS = 1E-8) |
| Графические графические процессоры | 4 nvidia a100 (80 ГБ) |
| Случайное семя | 2024 |
| валидация разделения | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
Конфигурацию модели можно найти в файле config.json .
Параметры аргумента конфигуратоина перечислены ниже:
| аргумент | описание |
|---|---|
| num_labels | Количество меток, используемых для классификации |
| num_gpu | Количество графических процессоров для использования |
| save_period_in_epochs | Число, характеризующее, какую периодичность сохраняется контрольная точка (в эпохи) |
| метрики | Используемые метрики классификации |
| Предварительный_модель_name | Bert Shortcut name от uggingface |
| table_serialization_type | Метод сериализации таблицы в последовательность |
| batch_size | Размер партии |
| num_epochs | Количество тренировочных эпох |
| random_seed | Случайное семя |
| logs_dir | Каталог для регистрации |
| train_log_filename | Имя файла для ведения журнала поездов |
| test_log_filename | Имя файла для журнала тестирования |
| start_from_checkpoint | Флаг, чтобы начать обучение с контрольно -пропускного пункта |
| CheckPoint_DIR | Каталог для хранения контрольных точек модели |
| CheckPoint_Name | Имя файла контрольной точки (состояние модели) |
| sepence_model_name | Имя файла модели для вывода |
| sepence_dir | Каталог для хранения таблиц вывода .csv |
| dataLoader.valid_split | Количество валидационного подмножества разделения |
| DataLoader.num_workers | Количество работников DataLoader |
| DateSet.num_rows | Количество читаемых строк в наборе данных, если null Reade все строки в файлах |
| DataSet.data_DIR | Справочник для хранения файлов поезда/тестирования/вывода |
| DataSet.train_Path | Каталог для хранения файлов наборов данных поезда .csv |
| DateSet.test_path | Direcotry для хранения файлов наборов данных тестовых данных .csv |
Мы рекомендуем изменять только те, которые параметры:
num_gpu - любой положительный номер индекса + {0}. 0 Останьте обучение / тестирование на процессоре.save_period_in_epochs - Любой положительный целый ряд, меры в эпохи.table_serialization_type - "column_wise" или "table_wise".pretrained_model_name - имена Bert Shorcut от Huggingface Pytorch, предварительно предоставленных моделями.batch_size - любой положительный целый ряд.num_epochs - любое положительное целое число.random_seed - любое целочисленное число.start_from_checkpoint - "true" или "false".checkpoint_name - любое имя модели, сохраненное в каталоге checkpoint .inference_model_name - любое имя модели, сохраненное в каталоге checkpoint . Но мы рекомендуем использовать лучшие модели: [MODEL_BEST_F1_WEELED.PT, MODEL_BEST_F1_MACRO.PT, MODEL_BEST_F1_MICRO.PT].dataloader.valid_split - Реальное число в диапазоне [0,0, 1,0] (0,0 означает 0 % подмножества поезда, 0,5 стоят 50 % подмножества поезда). Или положительное целое число (обозначение фиксированного числа подмножества проверки).dataset.num_rows - "null" означает чтение всех строк в файлах набора данных. Положительное целое число означает количество строк для чтения в файлах набора данных. Перед тренировкой / тестированием модели вам нужно:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh из репозитория набора данных, чтобы перемещать файлы наборов в каталог data Rutabert: RuTaBERT-Dataset$ ./move_dataset.shconfig.json перед обучением. Рутаберт поддерживает обучение / тестирование на локальном и внутри контейнера Docker. Также поддерживает Slurm Workload Manager.
RuTaBERT$ virtualenv venvили
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint .logs/ каталоге ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).Требования:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint .logs/ каталоге ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).RuTaBERT$ virtualenv venvили
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint .logs/ Directory ( train.log , test.log , error_train.log , error_test.log ). data/test .RuTaBERT$ ./download.sh table_wiseили
RuTaBERT$ ./download.sh column_wiseconfig.json .RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ каталоге ( test.log , error_test.log ). data/inference .RuTaBERT$ ./download.sh table_wiseили
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv