Ce n'est pas un produit Google officiellement pris en charge.
Il s'agit du code source du papier "Amos: un optimiseur de style Adam avec une décroissance de poids adaptative vers l'échelle orientée modèle".
Il implémente AMOS , un optimiseur compatible avec la bibliothèque Optax, et Jestimator , une bibliothèque de poids légère avec une interface de type tf.Estimator pour gérer les points de contrôle compatibles T5X pour les programmes d'apprentissage automatique dans JAX, que nous utilisons pour exécuter des expériences dans l'article.
pip install jestimator
Il installera l'optimiseur AMOS implémenté dans le Jestimator Lib.
Cette implémentation d'AMOS est utilisée avec Jax, une bibliothèque informatique numérique haute performance avec différenciation automatique, pour la recherche sur l'apprentissage automatique. L'API d'AMOS est compatible avec Optax, une bibliothèque d'optimistes JAX (, espérons-le, Amos sera intégré à Optax dans un avenir proche).
Afin de démontrer l'utilisation, nous appliquerons AMOS à MNIST. Il est basé sur l'exemple MNIST officiel de Flax, et vous pouvez trouver le code dans un carnet de jupyter ici.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
Un objet TrainState conserve les paramètres du modèle et les états d'optimiseur et peut être vérifié dans les fichiers.
Nous créons le modèle et l'optimiseur dans cette fonction.
Pour l'optimiseur, nous utilisons AMOS ici. Les hyper-paramètres suivants sont définis:
Le taux d'apprentissage global est généralement réglé sur le 1 / sqrt (n), où n est le nombre de lots dans les données de formation. Pour MNIST, nous avons des exemples de formation de 60k et la taille du lot est de 32. So Learning_Rate = 1 / SQRT (60000/32).
Le «ETA_FN» spécifique au modèle nécessite une fonction qui, étant donné un nom et une forme de variable, renvoie un flotteur indiquant l'échelle attendue de cette variable. Espérons que dans un avenir proche, nous aurons des bibliothèques qui pourront calculer automatiquement ce «Eta_FN» à partir du code de modélisation; Mais pour l'instant, nous devons le spécifier manuellement.
On peut utiliser la fonction d'assistance AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () pour créer 'eta_fn' à partir d'un ADSIGN_MAP. Un ADSIGN_MAP est un dict qui mappe les règles Regex à une valeur ou une expression python simple. Il trouvera la première règle Regex qui correspond au nom d'une variable et évaluera l'expression Python si nécessaire pour renvoyer la valeur. Voir notre exemple ci-dessous.
Le 'Shape_Fn' nécessite également une fonction qui, étant donné un nom et une forme variables, renvoie une forme réduite pour les variables de fente correspondantes. Nous pouvons utiliser la fonction d'assistance AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () pour créer 'Shape_Fn' à partir d'un ADSIGN_MAP.
«Beta» est le taux de désintégration exponentiel pour l'exécution de la moyenne des carrés de gradient. Nous l'avons réglé à 0,98 ici.
«Clip_value» est la valeur d'écrêtage du gradient, qui devrait correspondre à l'amplitude de la fonction de perte. Si la fonction de perte est une somme d'entrée croisée, nous devons définir «clip_value» au SQRT du nombre d'étiquettes.
Veuillez vous référer à notre article pour plus de détails sur les hyper-paramètres.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Utilisez le décorateur @jit de Jax pour compiler la fonction pour de meilleures performances.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Utilisez le décorateur @jit de Jax pour compiler la fonction pour de meilleures performances.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
Exécutez la boucle d'entraînement et évaluez sur les tests.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
Après 9 époques, nous devrions obtenir une précision de test de 99,26. Si vous l'avez fait, félicitations!
Avec Jestimator, vous pouvez construire votre modèle principalement similaire à l'exemple MNIST ci-dessus, mais sans écrire de code pour la section "principale"; Jestimator servira de point d'entrée pour votre modèle, gérera automatiquement la pointe de contrôle dans un mode de train / evalu-once / evalise-formation-et-épave du meilleur / prédire et configurez le profilage, le tensorboard et la journalisation.
De plus, Jestimator prend en charge le partitionnement du modèle qui est nécessaire pour la formation de très grands modèles sur plusieurs pods TPU. Il prend en charge un format de point de contrôle compatible T5X qui enregistre et restaure les points de contrôle de manière distribuée, ce qui convient aux grands modèles multi-pod.
Afin d'exécuter des modèles avec Jestimator, nous devons installer T5X et Flaxformer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
Ensuite, clonez ce dépôt pour obtenir le code de jestimateur:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
Maintenant, nous pouvons tester un modèle de régression linéaire jouet:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
Nous fournissons cet exemple MNIST pour démontrer comment écrire du code de modélisation avec Jestimator. Cela ressemble beaucoup à l'exemple ci-dessus, mais avec un grand avantage, un objet de configuration est passé pour collecter des informations à partir de drapeaux globaux et à l'ensemble de données, afin de configurer dynamiquement la modélisation. Cela facilite l'application du modèle à différents ensembles de données; Par exemple, on peut immédiatement essayer les ensembles de données Emnist ou Eurosat simplement en modifiant un argument en ligne de commande, sans modifier le code.
Avec la commande suivante, nous pouvons démarrer un emploi pour nous entraîner sur MNIST, enregistrer toutes les 100 étapes et enregistrer les points de contrôle sur $ home / expériences / MNIST / Modèles:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
Pendant ce temps, nous pouvons démarrer un emploi pour surveiller le dossier $ home / expériences / MNIST / modèles, évaluer sur un ensemble de tests MNIST et enregistrer le modèle avec la plus grande précision:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
En même temps, nous pouvons démarrer un tensorboard pour surveiller le processus:
tensorboard --logdir $HOME/experiments/mnist/models
Nous pouvons utiliser la commande suivante pour former une seule couche LSTM sur PTB:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
et évaluer:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
Il convient pour fonctionner sur une machine à GPU unique.
Voici quelques guides simples pour les modèles de pré-transfert et de type BERT, à l'aide de TPUS sur Google Cloud Platform (GCP). On peut commencer avec un navigateur Web avec une configuration zéro, en se connectant à une machine virtuelle via Google Cloud Console, sans rien installer localement. Si c'est la première fois, on est couvert par suffisamment de crédits pour essayer les commandes gratuitement.