Este não é um produto do Google oficialmente suportado.
Este é o código-fonte do artigo "Amos: um otimizador no estilo Adam com decaimento de peso adaptável em direção à escala orientada para o modelo".
Ele implementa Amos , um otimizador compatível com a biblioteca OptTax e o Jestimator , uma biblioteca leve com uma interface do tipo tf.Estimator para gerenciar pontos de verificação compatíveis com T5X para programas de aprendizado de máquina no JAX, que usamos para executar experimentos no papel.
pip install jestimator
Ele instalará o otimizador AMOS implementado no Jestimator Lib.
Essa implementação do AMOS é usada com o JAX, uma biblioteca de computação numérica de alto desempenho com diferenciação automática, para pesquisa de aprendizado de máquina. A API do AMOS é compatível com o Optax, uma biblioteca de otimizadores JAX (espero que o AMOS seja integrado ao Optax em um futuro próximo).
Para demonstrar o uso, aplicaremos o AMOS ao mnist. É baseado no exemplo oficial do Flax e você pode encontrar o código em um notebook Jupyter aqui.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
Um objeto TrainState mantém os parâmetros do modelo e os estados do otimizador e pode ser posicionado em arquivos.
Criamos o modelo e o otimizador nesta função.
Para o otimizador, usamos Amos aqui. Os seguintes hiper-parâmetros estão definidos:
A taxa de aprendizado global geralmente é definida como 1/SQRT (N), onde n é o número de lotes nos dados de treinamento. Para o MNIST, temos exemplos de treinamento de 60k e o tamanho do lote é 32. SO APRENDIZADO_RATE = 1/SQRT (60000/32).
O 'eta_fn' específico do modelo requer uma função que, dada um nome e forma de variável, retorna um flutuador indicando a escala esperada dessa variável. Felizmente, em um futuro próximo, teremos bibliotecas que possam calcular automaticamente esse 'eta_fn' do código de modelagem; Mas, por enquanto, temos que especificá -lo manualmente.
Pode -se usar o AMOS_HELPER.PARAMS_FN_FOM_ASSIGN_MAP () função auxiliar para criar 'eta_fn' a partir de um atribuído_map. Um cession_map é um ditado que mapeia as regras do regex para um valor ou expressão simples de python. Ele encontrará a primeira regra regex que corresponde ao nome de uma variável e avaliará a expressão do Python, se necessário, para retornar o valor. Veja o nosso exemplo abaixo.
O 'shape_fn' requer da mesma forma uma função que, dado um nome e forma de variável, retorna uma forma reduzida para as variáveis de slot correspondentes. Podemos usar o AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () função auxiliar para criar 'shape_fn' a partir de um atribuído_map também.
'Beta' é a taxa de decaimento exponencial para a média de funcionamento dos quadrados de gradiente. Nós o definimos para 0,98 aqui.
'clip_value' é o valor de recorte do gradiente, que deve corresponder à magnitude da função de perda. Se a função de perda for uma soma da entropia cruzada, devemos definir 'clip_value' para o SQRT do número de rótulos.
Consulte o nosso artigo para obter mais detalhes dos hiper-parâmetros.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Use o decorador @Jit da JAX para compilar a função para melhor desempenho.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Use o decorador @Jit da JAX para compilar a função para melhor desempenho.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
Execute o loop de treinamento e avalie no conjunto de testes.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
Após 9 épocas, devemos obter 99,26 precisão do teste. Se você fez isso, parabéns!
Com o Jestimator, você pode construir seu modelo principalmente semelhante ao exemplo mnist acima, mas sem escrever código para a seção "principal"; O Jestimator servirá como ponto de entrada para o seu modelo, lidará automaticamente com o Ponto de verificação em um modo de trem/avaliar/avaliar-while-while e salvar o melhor/prever o modo e configurar perfis, tensorboard e log.
Além disso, o Jestimator suporta a partição do modelo, necessária para o treinamento de modelos muito grandes em várias vagens de TPU. Ele suporta um formato de ponto de verificação compatível com T5X que salva e restaura os pontos de verificação de uma maneira distribuída, adequada para grandes modelos de vários cod.
Para executar modelos com Jestimator, precisamos instalar o T5X e o Flaxformer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
Em seguida, clone este repositório para obter o código Jestimator:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
Agora, podemos testar um modelo de regressão linear de brinquedo:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
Fornecemos este exemplo mnist para demonstrar como escrever o código de modelagem com o Jestimator. É muito parecido com o exemplo acima, mas com uma grande vantagem que um objeto de configuração é transmitido para coletar informações dos sinalizadores globais e do conjunto de dados, a fim de configurar dinamicamente a modelagem. Isso facilita a aplicação do modelo a diferentes conjuntos de dados; Por exemplo, pode-se experimentar imediatamente os conjuntos de dados emnist ou eurosat simplesmente alterando um argumento da linha de comando, sem modificar o código.
Com o comando a seguir, podemos iniciar um trabalho para treinar no mnist, registrar a cada 100 etapas e salvar os pontos de verificação em $ home/experimentos/mnist/modelos:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
Enquanto isso, podemos iniciar um trabalho para monitorar a pasta $ Home/Experimentos/MNIST/Modelos, avaliar o conjunto de testes MNIST e salvar o modelo com a maior precisão:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
Ao mesmo tempo, podemos iniciar um quadro de tensor para monitorar o processo:
tensorboard --logdir $HOME/experiments/mnist/models
Podemos usar o seguinte comando para treinar uma única camada LSTM no PTB:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
e avaliar:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
É adequado para executar em uma máquina de GPU único.
Aqui estão alguns guias simples para modelos de pré-treino e tune bert, usando TPUs na plataforma do Google Cloud (GCP). Pode -se começar com um navegador da web com configuração zero, conectando -se a uma máquina virtual via Google Cloud Console, sem instalar nada localmente. Se for a primeira vez, um é coberto por créditos suficientes para experimentar os comandos gratuitamente.