감독되지 않은 데이터 확대 또는 UDA는 다양한 언어 및 비전 작업에 대한 최첨단 결과를 달성하는 반 감독 학습 방법입니다.
20 개의 라벨이 붙은 예제로 UDA는 25,000 개의 라벨이 붙은 예제에 대해 훈련 된 IMDB의 이전 최신 ART보다 우수합니다.
| 모델 | 라벨이 붙은 예의 수 | 오류율 |
|---|---|---|
| 혼합 VAT (Prev. Sota) | 25,000 | 4.32 |
| 버트 | 25,000 | 4.51 |
| UDA | 20 | 4.20 |
CIFAR-10에서 최첨단 방법의 오류율의 30% 이상을 4,000 개의 라벨링 된 예제로, 1,000 개의 라벨링 된 예제가있는 SVHN을 줄입니다.
| 모델 | Cifar-10 | svhn |
|---|---|---|
| ICT (이전 소타) | 7.66 ± .17 | 3.53 ± .07 |
| UDA | 4.31 ± .08 | 2.28 ± .10 |
10% 라벨이 붙은 데이터로 ImageNet의 상당한 개선으로 이어집니다.
| 모델 | 상단 1 정확도 | 상위 5 개 정확도 |
|---|---|---|
| RESNET-50 | 55.09 | 77.26 |
| UDA | 68.78 | 88.80 |
UDA는 반 감독 학습 방법으로, 라벨이 붙은 예제의 필요성을 줄이고 표지되지 않은 예제를 더 잘 활용합니다.
우리는 다음을 출시하고 있습니다.
이 저장소의 모든 코드는 GPU 및 Google Cloud TPU와 함께 제공됩니다.
코드는 Python 2.7 및 Tensorflow 1.13에서 테스트됩니다. TensorFlow를 설치 한 후 다음 명령을 실행하여 종속성을 설치하십시오.
pip install --user absl-py우리는 모든 원래 예제에 대해 100 개의 증강 된 예를 생성합니다. 증강 된 모든 데이터를 다운로드하려면 이미지 디렉토리로 이동하여 실행하십시오.
AUG_COPY=100
bash scripts/download_cifar10.sh ${AUG_COPY}모든 증강 된 데이터에 120g 디스크 공간이 필요합니다. 공간을 절약하기 위해 Aug_copy를 30과 같은 더 작은 숫자로 설정할 수 있습니다.
또는, 당신은 실행하여 강화 된 예제를 직접 생성 할 수 있습니다.
AUG_COPY=100
bash scripts/preprocess.sh --aug_copy= ${AUG_COPY}GPU 명령 :
# UDA accuracy:
# 4000: 95.68 +- 0.08
# 2000: 95.27 +- 0.14
# 1000: 95.25 +- 0.10
# 500: 95.20 +- 0.09
# 250: 94.57 +- 0.96
bash scripts/run_cifar10_gpu.sh --aug_copy= ${AUG_COPY} # UDA accuracy:
# 4000: 97.72 +- 0.10
# 2000: 97.80 +- 0.06
# 1000: 97.77 +- 0.07
# 500: 97.73 +- 0.09
# 250: 97.28 +- 0.40
bash scripts/run_svhn_gpu.sh --aug_copy= ${AUG_COPY} IMDB의 영화 검토 텍스트는 많은 분류 작업보다 길기 때문에 더 긴 시퀀스 길이를 사용하면 더 나은 성능을 제공합니다. 시퀀스 길이는 BERT를 사용할 때 TPU/GPU 메모리에 의해 제한됩니다 (버트의 메모리 외 문제 참조). 따라서, 우리는 더 짧은 시퀀스 길이와 더 작은 배치 크기로 실행할 스크립트를 제공합니다.
11GB 메모리가있는 GPU에서 BERT베이스로 UDA를 실행하려면 텍스트 디렉토리로 이동하여 다음 명령을 실행하십시오.
# Set a larger max_seq_length if your GPU has a memory larger than 11GB
MAX_SEQ_LENGTH=128
# Download data and pretrained BERT checkpoints
bash scripts/download.sh
# Preprocessing
bash scripts/prepro.sh --max_seq_length= ${MAX_SEQ_LENGTH}
# Baseline accuracy: around 68%
bash scripts/run_base.sh --max_seq_length= ${MAX_SEQ_LENGTH}
# UDA accuracy: around 90%
# Set a larger train_batch_size to achieve better performance if your GPU has a larger memory.
bash scripts/run_base_uda.sh --train_batch_size=8 --max_seq_length= ${MAX_SEQ_LENGTH}
이 논문의 최상의 성능은 Max_Seq_length 512를 사용하고 감독되지 않은 데이터에 대한 Bert 대형 Finetuned로 초기화하여 달성됩니다. Google Cloud TPU v3-32 POD에 액세스 할 수있는 경우 :
MAX_SEQ_LENGTH=512
# Download data and pretrained BERT checkpoints
bash scripts/download.sh
# Preprocessing
bash scripts/prepro.sh --max_seq_length= ${MAX_SEQ_LENGTH}
# UDA accuracy: 95.3% - 95.9%
bash train_large_ft_uda_tpu.sh우선, 다음 종속성을 설치하십시오.
pip install --user nltk
python -c " import nltk; nltk.download('punkt') "
pip install --user tensor2tensor==1.13.4다음 명령은 제공된 예제 파일을 변환합니다. 단락을 문장으로 자동 분할하고 영어 문장을 프랑스어로 번역 한 다음 영어로 다시 번역합니다. 마지막으로, 그것은 말의 문장을 단락으로 구성합니다. back_translate 디렉토리로 이동하여 실행하십시오.
bash download.sh
bash run.shbash 파일에는 변수 샘플링 _temp가 있습니다. 그것은 역설의 다양성과 품질을 제어하는 데 사용됩니다. 샘플링 _temp를 증가 시키면 다양성이 증가하지만 품질이 악화 될 것입니다. 놀랍게도, 다양성은 우리가 시도한 많은 작업에서 품질보다 더 중요합니다.
샘플링 _temp를 0.7, 0.8 및 0.9로 설정하는 것이 좋습니다. 작업이 소음에 매우 강력한 경우 Sampling_Temp = 0.9 또는 0.8로 인해 성능이 향상됩니다. 작업이 노이즈에 강력하지 않은 경우 샘플링 온도를 0.7 또는 0.6으로 설정하는 것이 더 나을 것입니다.
큰 파일로 번역을 다시하려면 run.sh에서 복제품 및 worker_id 인수를 변경할 수 있습니다. 예를 들어, replicas = 3 인 경우 데이터를 세 부분으로 나누고 각 run.sh는 worker_id에 따라 한 부분 만 처리합니다.
UDA는 박스 외부에서 작동하며 광범위한 하이퍼 파라미터 튜닝이 필요하지는 않지만 실제로 성능을 발휘하기 위해 하이퍼 파 램터에 대한 제안은 다음과 같습니다.
코드의 상당 부분은 Bert와 Randaugment에서 가져옵니다. 감사해요!
UDA를 사용하는 경우이 백서를 인용하십시오.
@article{xie2019unsupervised,
title={Unsupervised Data Augmentation for Consistency Training},
author={Xie, Qizhe and Dai, Zihang and Hovy, Eduard and Luong, Minh-Thang and Le, Quoc V},
journal={arXiv preprint arXiv:1904.12848},
year={2019}
}
이미지에 UDA를 사용하는 경우이 백서를 인용하십시오.
@article{cubuk2019randaugment,
title={RandAugment: Practical data augmentation with no separate search},
author={Cubuk, Ekin D and Zoph, Barret and Shlens, Jonathon and Le, Quoc V},
journal={arXiv preprint arXiv:1909.13719},
year={2019}
}
이것은 공식적으로 지원되는 Google 제품이 아닙니다.