Dies ist kein offiziell unterstütztes Google -Produkt.
Dies ist der Quellcode für das Papier "AMOS: Ein Optimierer im Adam-Stil mit adaptivem Gewichtsverfall in Richtung modellorientierter Skala".
Es implementiert Amos , einen mit der Optax-Bibliothek kompatiblen Optimierer, und einen Jestimator , eine leichte Bibliothek mit einer tf.Estimator -ähnlichen Schnittstelle, um T5X-kompatible Kontrollpunkte für maschinelle Lernprogramme in JAX zu verwalten, mit denen wir Experimente im Papier ausführen können.
pip install jestimator
Es wird den im Jestimator LIB implementierten AMOS -Optimierer installiert.
Diese Implementierung von AMOS wird mit JAX, einer leistungsstarken numerischen Computerbibliothek mit automatischer Differenzierung, für maschinelle Lernforschung verwendet. Die API von AMOS ist mit Optax, einer Bibliothek von JAX -Optimierern, kompatibel (hoffentlich wird AMOS in naher Zukunft in Optax integriert).
Um die Verwendung zu demonstrieren, werden wir Amos auf MNIST anwenden. Es basiert auf dem offiziellen MNIST -Beispiel von Flax und finden Sie den Code in einem Jupyter -Notizbuch hier.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
Ein TrainState -Objekt hält die Modellparameter und Optimiererzustände und kann in Dateien aktiviert werden.
Wir erstellen das Modell und den Optimierer in dieser Funktion.
Für den Optimierer verwenden wir AMOs hier. Die folgenden Hyperparameter sind festgelegt:
Die globale Lernrate wird normalerweise auf 1/sqrt (n) festgelegt, wobei n die Anzahl der Chargen in den Trainingsdaten ist. Für MNIST haben wir 60k -Trainingsbeispiele und die Chargengröße beträgt 32. So Learning_Rate = 1/SQRT (60000/32).
Der modellspezifische 'eta_fn' erfordert eine Funktion, die bei einem variablen Namen und einer variablen Form einen Float zurückgibt, der die erwartete Skala dieser Variablen angibt. Hoffentlich haben wir in naher Zukunft Bibliotheken, die diese 'eta_fn' automatisch aus dem Modellierungscode berechnen können. Aber jetzt müssen wir es manuell angeben.
Man kann die Funktion amos_helper.params_fn_from_assign_map () Helfer verwenden, um 'eta_fn' aus einem Astromenten_Map zu erstellen. Ein Astroming_Map ist ein Diktat, das die Regex -Regeln für einen Wert oder einen einfachen Python -Ausdruck ordnet. Es wird die erste Regex -Regel gefunden, die dem Namen einer Variablen entspricht, und den Python -Ausdruck bei Bedarf bewertet, um den Wert zurückzugeben. Siehe unser Beispiel unten.
Die 'Shape_FN' erfordert in ähnlicher Weise eine Funktion, die bei einem variablen Namen und einer variablen Form eine reduzierte Form für die entsprechenden Schlitzvariablen zurückgibt. Wir können die Funktion von AMOS_HELPER.PARAMS_FN_FROM_ASGE_MAP () Helfer verwenden, um 'Shape_Fn' auch aus einem Astroming_Map zu erstellen.
'Beta' ist die exponentielle Zerfallsrate für den Durchschnitt der Gradientenquadrate. Wir haben es hier auf 0,98 gesetzt.
'clip_value' ist der Gradienten -Clipping -Wert, der mit der Größe der Verlustfunktion übereinstimmt. Wenn die Verlustfunktion eine Summe der Querentropie ist, sollten wir "clip_value" auf die Anzahl der Etiketten einstellen.
Weitere Informationen zu den Hyper-Parametern finden Sie in unserer Zeitung.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Verwenden Sie den @jit-Dekorator von JAX, um die Funktion für eine bessere Leistung zu kompilieren.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Verwenden Sie den @jit-Dekorator von JAX, um die Funktion für eine bessere Leistung zu kompilieren.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
Führen Sie die Trainingsschleife aus und bewerten Sie sie im Testsatz.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
Nach 9 Epochen sollten wir 99,26 Testgenauigkeit erhalten. Wenn Sie es geschafft haben, herzlichen Glückwunsch!
Mit Jestimator können Sie Ihr Modell, das dem oben genannten MNIST -Beispiel hauptsächlich ähnelt, jedoch ohne Code für den Abschnitt "Haupt" erstellen. Der Jestimator dient als Einstiegspunkt für Ihr Modell, handelt automatisch in einem Zug-/Evalo-Once/Eval-Training-und-Save-the-Best/--Vorhersage-Modus und setzt Profilerstellung, Tensorboard und Protokollierung ein.
Darüber hinaus unterstützt Jestimator die Modellpartitionierung, die für das Training sehr großer Modelle in mehreren TPU -Pods erforderlich ist. Es unterstützt ein T5X-kompatibles Checkpoint-Format, das Kontrollpunkte auf verteilte Weise speichert und wiederherstellt, was für große Mehrfach-Pod-Modelle geeignet ist.
Um Modelle mit Jestimator auszuführen, müssen wir T5X und Flaxformer installieren:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
Klonen Sie dann dieses Repo, um den Jestimatorcode zu erhalten:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
Jetzt können wir ein lineares Regressionsmodell mit Spielzeug testen:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
Wir geben dieses MNIST -Beispiel an, um zu demonstrieren, wie man Modellierungscode mit Jestimator schreibt. Es ist ähnlich wie das obige Beispiel, aber mit einem großen Vorteil, dass ein Konfigurationsobjekt umsetzt wird, um Informationen von globalen Flags und dem Datensatz zu sammeln, um die Modellierung dynamisch einzurichten. Dies erleichtert es, das Modell auf verschiedene Datensätze anzuwenden. Zum Beispiel kann man sofort die EMNIST- oder EuroSat-Datensätze ausprobieren, indem man einfach ein Befehlszeilenargument ändern, ohne den Code zu ändern.
Mit dem folgenden Befehl können wir einen Job für MNIST starten, alle 100 Schritte protokollieren und die Kontrollpunkte auf $ home/experimente/mnist/models speichern:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
In der Zwischenzeit können wir einen Job beginnen, um den Ordner $ Home/Experimente/MNIST/Models zu überwachen, am MNIST -Testsatz zu bewerten und das Modell mit der höchsten Genauigkeit zu speichern:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
Gleichzeitig können wir ein Tensorboard starten, um den Prozess zu überwachen:
tensorboard --logdir $HOME/experiments/mnist/models
Wir können den folgenden Befehl verwenden, um eine einzelne LSTM auf PTB zu trainieren:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
und bewerten:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
Es eignet sich zum Laufen auf einer Single-GPU-Maschine.
Hier finden Sie einige einfache Anleitungen zu den Bert-ähnlichen Modellen vor dem Training und zu Feinabstimmung unter Verwendung von TPUs auf der Google Cloud Platform (GCP). Man kann mit einem Webbrowser mit Null -Setup beginnen, indem man über Google Cloud Console eine Verbindung zu einem virtuellen Computer herstellt, ohne etwas lokal zu installieren. Wenn dies das erste Mal ist, wird man von genügend Credits abgedeckt, um die Befehle von frei zu probieren.