이것은 공식적으로 지원되는 Google 제품이 아닙니다.
이것은 종이 "AMOS : 모델 지향적 척도로 적응 무게가 부패한 Adam 스타일의 최적화기"의 소스 코드입니다.
Optax 라이브러리와 호환되는 Optimizer와 JAX의 기계 학습 프로그램에 대한 T5X 호환 체크 포인트를 관리하기위한 tf.Estimator -Like Interface가있는 가벼운 라이브러리 인 Jestimator 인 AMOS를 구현합니다.
pip install jestimator
Jestimator Lib에 구현 된 AMOS Optimizer를 설치합니다.
이 AMOS 구현은 기계 학습 연구를 위해 자동 차별화를 갖춘 고성능 수치 컴퓨팅 라이브러리 인 JAX와 함께 사용됩니다. AMOS의 API는 JAX 최적화 라이브러리 라이브러리 인 Optax와 호환됩니다 (AMOS는 가까운 시일 내에 Optax에 통합 될 것입니다).
사용법을 입증하기 위해 AMOS를 MNIST에 적용 할 것입니다. Flax의 공식 MNIST 예제를 기반으로하며 여기에서 Jupyter 노트에서 코드를 찾을 수 있습니다.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
TrainState 객체는 모델 매개 변수와 최적화 상태를 유지하고 파일로 확인할 수 있습니다.
이 기능에서 모델과 Optimizer를 만듭니다.
Optimizer의 경우 여기에서 AMOS를 사용합니다. 다음과 같은 하이퍼 파라미터가 설정됩니다.
글로벌 학습 속도는 일반적으로 1/sqrt (n)로 설정되며, 여기서 n은 훈련 데이터의 배치 수입니다. MNIST의 경우 60K 교육 예제가 있으며 배치 크기는 32입니다. 따라서 Learning_rate = 1/sqrt (60000/32).
모델 별 'ETA_FN'에는 변수 이름과 모양이 주어지면 해당 변수의 예상 스케일을 나타내는 플로트를 반환하는 함수가 필요합니다. 가까운 시일 내에 모델링 코드 에서이 'ETA_FN'을 자동으로 계산할 수있는 라이브러리가 있기를 바랍니다. 그러나 지금은 수동으로 지정해야합니다.
amos_helper.params_fn_from_assign_map () 도우미 함수를 사용하여 antart_map에서 'eta_fn'을 생성 할 수 있습니다. antart_map은 regex 규칙을 값 또는 간단한 Python 표현식에 매핑하는 dict입니다. 변수의 이름과 일치하는 첫 번째 Regex 규칙을 찾아 값을 반환하기 위해 필요한 경우 Python 표현식을 평가합니다. 아래의 예를 참조하십시오.
'shape_fn'도 마찬가지로 변수 이름과 모양이 주어지면 해당 슬롯 변수에 대해 감소 된 모양을 반환하는 함수가 필요합니다. amos_helper.params_fn_from_assign_map () 도우미 함수를 사용하여 antart_map에서 'shape_fn'을 생성 할 수 있습니다.
'베타'는 그라디언트 사각형의 평균 실행에 대한 지수 붕괴 속도입니다. 우리는 여기에서 0.98로 설정했습니다.
'clip_value'는 그라디언트 클리핑 값으로 손실 함수의 크기와 일치해야합니다. 손실 함수가 크로스-엔트로피의 합인 경우 라벨 수의 SQRT로 'clip_value'를 설정해야합니다.
하이퍼 파라미터에 대한 자세한 내용은 당사 논문을 참조하십시오.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Jax의 @jit 데코레이터를 사용하여 더 나은 성능을 위해 기능을 컴파일하십시오.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Jax의 @jit 데코레이터를 사용하여 더 나은 성능을 위해 기능을 컴파일하십시오.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
교육 루프를 실행하고 테스트 세트에서 평가하십시오.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
9 개의 에포크 후, 우리는 99.26 테스트 정확도를 얻어야합니다. 당신이 그것을 만들었다면, 축하합니다!
Jestimator를 사용하면 위의 MNIST 예제와 유사하지만 "기본"섹션에 대한 코드를 작성하지 않고 모델을 구축 할 수 있습니다. Jestimator는 모델의 진입 지점 역할을하며 기차/평가 온/평가-비트 및 사례 최신/예측 모드에서 자동으로 검사 점을 처리하고 프로파일 링, 텐서 보드 및 로깅을 설정합니다.
또한 Jestimator는 여러 TPU 포드에서 매우 큰 모델을 훈련시키는 데 필요한 모델 파티셔닝을 지원합니다. 대규모 멀티 포드 모델에 적합한 분산 방식으로 체크 포인트를 저장하고 복원하는 T5X 호환 체크 포인트 형식을 지원합니다.
Jestimator로 모델을 실행하려면 T5x 및 Flaxformer를 설치해야합니다.
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
그런 다음이 저장소를 복제하여 Jestimator 코드를 얻습니다.
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
이제 장난감 선형 회귀 모델을 테스트 할 수 있습니다.
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
우리는이 mnist 예제를 제공하여 Jestimator로 모델링 코드를 작성하는 방법을 보여줍니다. 위의 예와 매우 흡사하지만 큰 장점으로 구성 모델링을 동적으로 설정하기 위해 전역 플래그 및 데이터 세트에서 정보를 수집하기 위해 구성 객체가 전달됩니다. 이를 통해 모델을 다른 데이터 세트에 쉽게 적용 할 수 있습니다. 예를 들어, 코드를 수정하지 않고 명령 줄 인수를 변경하여 즉시 EMNIST 또는 EUROSAT 데이터 세트를 시도 할 수 있습니다.
다음 명령을 사용하면 MNIST를 훈련시키기 위해 작업을 시작하고 100 단계마다 로그인 한 후 체크 포인트를 $ home/Experiments/MNIST/Models에 저장할 수 있습니다.
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
한편, 우리는 $ home/실험/mnist/model 폴더를 모니터링하고 MNIST 테스트 세트를 평가하고 가장 높은 정확도로 모델을 저장하기 위해 작업을 시작할 수 있습니다.
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
동시에, 우리는 텐서 보드를 시작하여 프로세스를 모니터링 할 수 있습니다.
tensorboard --logdir $HOME/experiments/mnist/models
다음 명령을 사용하여 PTB에서 단일 레이어 LSTM을 훈련시킬 수 있습니다.
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
평가 :
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
단일 GPU 머신에서 실행하는 데 적합합니다.
다음은 GCP (Google Cloud Platform)의 TPU를 사용하여 사전 훈련 및 미세 조정 Bert와 같은 모델을위한 간단한 가이드입니다. 로컬로 설치하지 않고 Google Cloud 콘솔을 통해 가상 시스템에 연결하여 설정이없는 웹 브라우저로 시작할 수 있습니다. 이것이 처음이라면, 하나는 명령을 무료로 시도하기에 충분한 크레딧으로 덮여 있습니다.