Modelo para resolver o problema da anotação do tipo de coluna com o BERT, treinado no conjunto de dados RWT-Rutabert.
O conjunto de dados RWT-Rutabert contém 1 441 349 colunas de tabelas russas da Wikipedia. Com cabeçalhos correspondentes a 170 tipos semânticos de dbpedia. Possui divisão de trem / teste fixo:
| Dividir | Colunas | Mesas | Avg. colunas por tabela |
|---|---|---|---|
| Teste | 115 448 | 55 080 | 2.096 |
| Trem | 1 325 901 | 633 426 | 2.093 |
Treinamos Rutabert com duas estratégias de serialização de tabela:
Resultados de referência no conjunto de dados RWT-Rutabert:
| Estratégia de serialização | micro-f1 | macro-f1 | ponderado-f1 |
|---|---|---|---|
| Multi-coluna | 0,962 | 0,891 | 0,9621 |
| Coluna vizinha | 0,964 | 0,904 | 0,9639 |
Parâmetros de treinamento:
| Parâmetro | Valor |
|---|---|
| Tamanho do lote | 32 |
| épocas | 30 |
| Função de perda | Entropia cruzada |
| Otimizador de GD | Adamw (LR = 5E-5, EPS = 1E-8) |
| GPU's | 4 NVIDIA A100 (80 GB) |
| semente aleatória | 2024 |
| Validação dividida | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
A configuração do modelo pode ser encontrada no arquivo config.json .
Os parâmetros de argumento da configuração estão listados abaixo:
| argumento | descrição |
|---|---|
| Num_labels | Número de rótulos usados para classificação |
| NUM_GPU | Número de GPUs para usar |
| save_period_in_epochs | Número caracterizando com que periodicidade o ponto de verificação é salvo (em épocas) |
| métricas | As métricas de classificação usadas são |
| pré -tereado_model_name | Nome de atalho Bert de Huggingface |
| tabela_Serialization_Type | Método de serializar uma tabela em uma sequência |
| batch_size | Tamanho do lote |
| num_epochs | Número de épocas de treinamento |
| Random_seed | Semente aleatória |
| logs_dir | Diretório para registro |
| TRIN_LOG_FILENAME | Nome do arquivo para registro de trem |
| test_log_filename | Nome do arquivo para registro de teste |
| start_from_checkpoint | Sinalizador para começar a treinar do ponto de verificação |
| Checkpoint_dir | Diretório para armazenar pontos de verificação do modelo |
| Checkpoint_name | Nome do arquivo de um ponto de verificação (estado do modelo) |
| inference_model_name | Nome do arquivo de um modelo para inferência |
| inference_dir | Diretório para armazenar tabelas de inferência .csv |
| DATALOADER.VALID_SPLIT | Quantidade de divisão de subconjunto de validação |
| dataloader.num_workers | Número de trabalhadores do Dataloader |
| DataSet.num_rows | Número de linhas legíveis no conjunto de dados, se null ler todas as linhas nos arquivos |
| DataSet.data_dir | Diretório para armazenar arquivos de trem/teste/inferência |
| DataSet.Train_Path | Diretório para armazenar arquivos de conjunto de dados de trem .csv |
| DataSet.test_path | Diretor para armazenar arquivos de conjunto de dados de teste .csv |
Recomendamos para mudar apenas esses parâmetros:
num_gpu - qualquer número de ingestão positivo + {0}. 0 Crie treinamento / teste na CPU.save_period_in_epochs - qualquer número inteiro positivo, medidas em épocas.table_serialization_type - "Column_wise" ou "tabela_wise".pretrained_model_name - nomes de Shorcut Bert de modelos Huggingface Pytorch pré -teria.batch_size - qualquer número inteiro positivo.num_epochs - qualquer número inteiro positivo.random_seed - qualquer número inteiro.start_from_checkpoint - "true" ou "false".checkpoint_name - Qualquer nome do modelo, salvo no diretório checkpoint .inference_model_name - qualquer nome do modelo, salvo no diretório checkpoint . Mas recomendamos usar os melhores modelos: [model_best_f1_weighted.pt, model_best_f1_macro.pt, model_best_f1_micro.pt].dataloader.valid_split - Número real dentro do intervalo [0,0, 1,0] (0,0 significa 0 % do subconjunto de trem, 0,5 significa 50 % do subconjunto de trem). Ou número inteiro positivo (denotando um número fixo de um subconjunto de validação).dataset.num_rows - "Null" significa ler todas as linhas nos arquivos do conjunto de dados. Inteiro positivo significa o número de linhas para ler nos arquivos do conjunto de dados. Antes de treinar / testar o modelo que você precisa:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh do repositório do conjunto de dados, para mover os arquivos do conjunto de dados para o diretório data Rutabert: RuTaBERT-Dataset$ ./move_dataset.shconfig.json antes do treinamento. O Rutabert suporta treinamento / teste localmente e dentro do contêiner do Docker. Também suporta o Slurm Workload Manager.
RuTaBERT$ virtualenv venvou
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint .logs/ diretório ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).Requisitos:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint .logs/ diretório ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).RuTaBERT$ virtualenv venvou
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint .logs/ diretório ( train.log , test.log , error_train.log , error_test.log ). data/test .RuTaBERT$ ./download.sh table_wiseou
RuTaBERT$ ./download.sh column_wiseconfig.json .RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ diretório ( test.log , error_test.log ). data/inference .RuTaBERT$ ./download.sh table_wiseou
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv