Este no es un producto de Google compatible oficialmente.
Este es el código fuente del documento "Amos: un optimizador de estilo Adam con descomposición de peso adaptativo hacia la escala orientada al modelo".
Implementa AMOS , un optimizador compatible con la biblioteca OPTAX, y Jestimator , una biblioteca de peso ligero con una interfaz similar a tf.Estimator para administrar puntos de control compatibles con T5X para programas de aprendizaje automático en JAX, que usamos para ejecutar experimentos en el documento.
pip install jestimator
Instalará el optimizador AMOS implementado en el jestimator lib.
Esta implementación de AMOS se utiliza con Jax, una biblioteca de computación numérica de alto rendimiento con diferenciación automática, para la investigación del aprendizaje automático. La API de AMOS es compatible con OPTAX, una biblioteca de optimizadores de Jax (con suerte, Amos se integrará en OPTAX en el futuro cercano).
Para demostrar el uso, aplicaremos AMOS a Mnist. Se basa en el ejemplo MNIST oficial de Flax, y puede encontrar el código en un cuaderno Jupyter aquí.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
Un objeto TrainState mantiene los parámetros del modelo y los estados optimizadores, y se puede poner en cuenta en archivos.
Creamos el modelo y el optimizador en esta función.
Para el optimizador, usamos AMOS aquí. Se establecen los siguientes hiper-parámetros:
La tasa de aprendizaje global generalmente se establece en 1/SQRT (N), donde N es el número de lotes en los datos de capacitación. Para MNIST, tenemos ejemplos de capacitación de 60k y el tamaño de lotes es 32. Entonces Learning_Rate = 1/SQRT (60000/32).
El 'eta_fn' específico del modelo requiere una función que, dado un nombre y forma variable, devuelva un flotador que indique la escala esperada de esa variable. Esperemos que en el futuro cercano tengamos bibliotecas que puedan calcular automáticamente este 'eta_fn' del código de modelado; Pero por ahora tenemos que especificarlo manualmente.
Uno puede usar la función de ayuda para amos_helper.params_fn_from_assign_map () para crear 'eta_fn' desde un asign_map. Un asign_map es un dict que mapea reglas regex a un valor o expresión simple de pitón. Encontrará la primera regla regex que coincide con el nombre de una variable y evalúa la expresión de Python si es necesario para devolver el valor. Vea nuestro ejemplo a continuación.
El 'sape_fn' requiere de manera similar una función que, dado un nombre y forma de variable, devuelva una forma reducida para las variables de ranura correspondientes. Podemos usar la función de ayuda para amos_helper.params_fn_from_assign_map () para crear 'shape_fn' de un asign_map también.
'beta' es la tasa de descomposición exponencial para ejecutar el promedio de cuadrados de gradiente. Lo establecemos en 0.98 aquí.
'Clip_Value' es el valor de recorte de gradiente, que debería coincidir con la magnitud de la función de pérdida. Si la función de pérdida es una suma de entropía cruzada, entonces debemos establecer 'clip_value' en el sqrt del número de etiquetas.
Consulte nuestro artículo para obtener más detalles de los hiperparametros.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Use el decorador de Jax @jit para compilar la función para un mejor rendimiento.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Use el decorador de Jax @jit para compilar la función para un mejor rendimiento.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
Ejecute el bucle de entrenamiento y evalúe en el conjunto de pruebas.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
Después de 9 épocas, deberíamos obtener 99.26 precisión de la prueba. Si lo lograste, ¡felicidades!
Con Jestimator, puede construir su modelo principalmente similar al ejemplo de MNIST anterior, pero sin escribir código para la sección "principal"; Jestimator servirá como punto de entrada para su modelo, manejará automáticamente el punto de control en un modo de trenes/evaluación/evaluación/evaluación y salvación del mejor/predicto, y configure el perfil, la placa tensorial y el registro.
Además, Jestimator admite la partición del modelo que se requiere para capacitar modelos muy grandes en múltiples vainas de TPU. Admite un formato de punto de control compatible con T5X que guarda y restaura los puntos de control de manera distribuida, que es adecuada para grandes modelos de múltiples pods.
Para ejecutar modelos con Jestimator, necesitamos instalar T5X y Flaxformer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
Luego, clone este repositorio para obtener el código Jestimator:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
Ahora, podemos probar un modelo de regresión lineal de juguete:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
Proporcionamos este ejemplo de MNIST para demostrar cómo escribir código de modelado con Jestimator. Es muy parecido al ejemplo anterior, pero con una gran ventaja que se pasa un objeto de configuración para recopilar información de los indicadores globales y el conjunto de datos, para configurar dinámicamente el modelado. Esto hace que sea más fácil aplicar el modelo a diferentes conjuntos de datos; Por ejemplo, uno puede probar inmediatamente los conjuntos de datos EMNIST o EUROSAT simplemente cambiando un argumento de línea de comandos, sin modificar el código.
Con el siguiente comando, podemos comenzar un trabajo para entrenar en MNIST, registrar cada 100 pasos y guardar los puntos de control en $ Home/Experiments/MNIST/Models:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
Mientras tanto, podemos comenzar un trabajo para monitorear la carpeta $ Home/Experiments/MNIST/Models, evaluar en el conjunto de pruebas MNIST y guardar el modelo con la más alta precisión:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
Al mismo tiempo, podemos iniciar una placa tensor para monitorear el proceso:
tensorboard --logdir $HOME/experiments/mnist/models
Podemos usar el siguiente comando para entrenar una sola capa LSTM en PTB:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
y evaluar:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
Es adecuado para ejecutar en una sola máquina GPU.
Aquí hay algunas guías simples para pre-entrenar y ajustar los modelos tipo Bert, utilizando TPUS en Google Cloud Platform (GCP). Uno puede comenzar con un navegador web con configuración cero, conectándose a una máquina virtual a través de Google Cloud Console, sin instalar nada localmente. Si esta es la primera vez, uno está cubierto por suficientes créditos para probar los comandos de forma gratuita.