Это не официально поддерживаемый продукт Google.
Это исходный код для статьи «Amos: оптимизатор в стиле Adam с адаптивным распадом веса в направлении масштаба ориентированной на модель».
Он реализует Amos , оптимизатор, совместимый с библиотекой Optax, и Jestimator , легкой библиотекой с интерфейсом, похожим на tf.Estimator , для управления T5X-совместимыми контрольно-пропускными пунктами для программ машинного обучения в JAX, который мы используем для экспериментов в статье.
pip install jestimator
Он установит оптимизатор AMOS, реализованный в Jestimator Lib.
Эта реализация AMOS используется с JAX, высокоэффективной численной вычислительной библиотекой с автоматической дифференциацией для исследований машинного обучения. API AMOS совместим с Optax, библиотекой оптимизаторов JAX (надеюсь, Amos будет интегрирован в Optax в ближайшем будущем).
Чтобы продемонстрировать использование, мы применим Amos к Mnist. Он основан на официальном примере MNIST от льна, и вы можете найти код в ноутбуке Юпитера здесь.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
Объект TrainState сохраняет параметры модели и состояния оптимизатора и может быть покрыт в файлы.
Мы создаем модель и оптимизатор в этой функции.
Для оптимизатора мы используем Amos здесь. Следующие гиперпараметры установлены:
Глобальный уровень обучения обычно устанавливается на 1/SQRT (n), где n - количество партий в данных обучения. Для MNIST у нас есть примеры обучения 60K, а размер партии составляет 32. Таким образом, Learning_Rate = 1/SQRT (60000/32).
Специфичная для модели 'eta_fn' требует функции, которая, учитывая имя и форму переменной, возвращает плавание, указывающее на ожидаемую шкалу этой переменной. Надеемся, что в ближайшем будущем у нас будут библиотеки, которые могут автоматически рассчитать это «eta_fn» из кода моделирования; Но сейчас мы должны указать это вручную.
Можно использовать вспомогательную функцию amos_helper.params_fn_from_assign_map () для создания «eta_fn» из назначения_map. Assign_map - это дикт, который отображает правила Regex на значение или простое выражение Python. Он найдет первое правило режима, которое соответствует имени переменной, и оценит выражение Python, если это необходимо, чтобы вернуть значение. Смотрите наш пример ниже.
Аналогичным образом требуется функция, которая, учитывая имя и форму переменной, возвращает уменьшенную форму для соответствующих переменных слота. Мы можем использовать вспомогательную функцию AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () для создания 'SHAME_FN' также из назначения_Мапа.
«Бета» - это экспоненциальная скорость распада для среднего уровня для пробега градиентных квадратов. Мы установили его на 0,98 здесь.
'clip_value' - это значение градиентного отсечения, которое должно соответствовать величине функции потери. Если функция потери является суммой перекрестной энтропии, то мы должны установить «clip_value» на SQRT количества метков.
Пожалуйста, обратитесь к нашей статье для получения более подробной информации о гиперпараметрах.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Используйте Jax's @jit Decorator, чтобы просто введите время для лучшей производительности.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Используйте Jax's @jit Decorator, чтобы просто введите время для лучшей производительности.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
Запустите петлю обучения и оцените тестовый набор.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
После 9 эпох мы должны получить точность тестирования 99,26. Если вы сделали это, поздравляю!
С помощью JESTIMATOR вы можете построить свою модель, в основном похожей на пример MNIST выше, но без написания кода для разделения «Основной»; JESTIMATOR будет служить точкой входа для вашей модели, автоматически обрабатывает контрольно-пропускной пункт в режиме Train/Eval-Once/Eval-While Training and Save The Save The Say/прогноз, а также настраивает профилирование, тензоры и регистрацию.
Кроме того, Jestimator поддерживает модельное разделение, которое необходимо для обучения очень больших моделей для нескольких стручков TPU. Он поддерживает совместимый с T5X форматом контрольной точки, который сохраняет и восстанавливает контрольные точки распределенным образом, что подходит для крупных моделей с несколькими лаком.
Чтобы запустить модели с Jestimator, нам нужно установить T5X и FlaxFormer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
Затем клонируйте этот репо, чтобы получить код Jestimator:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
Теперь мы можем проверить модель линейной регрессии игрушки:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
Мы приводим этот пример MNIST, чтобы продемонстрировать, как писать код моделирования с помощью JESTIMATOR. Это очень похоже на пример выше, но с большим преимуществом, что объект конфигурации передается для сбора информации из глобальных флагов и набора данных, для динамической настройки моделирования. Это облегчает применение модели к различным наборам данных; Например, можно сразу же попробовать наборы данных Emnist или Eurosat, просто изменив аргумент командной строки, без изменения кода.
С помощью следующей команды мы можем начать работу по обучению на MNIST, войти в систему каждые 100 шагов и сохранить контрольно -пропускные пункты в $ home/experiments/mnist/models:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
Между тем, мы можем начать работу по мониторингу папки $ Home/Experiments/MNIST/Models, оценить набор тестового набора MNIST и сохранить модель с самой высокой точностью:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
В то же время мы можем запустить тензоры для мониторинга процесса:
tensorboard --logdir $HOME/experiments/mnist/models
Мы можем использовать следующую команду для обучения одного слоя LSTM на PTB:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
и оценить:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
Он подходит для работы на машине с одним-GPU.
Вот несколько простых направляющих для моделей до тренировок и тонкой настройки BERT, использующих TPU на платформе Google Cloud (GCP). Можно начать с веб -браузера с нулевой настройкой, подключившись к виртуальной машине через Cloud Console Google, не устанавливая ничего локально. Если это первый раз, один охватывается достаточным количеством кредитов, чтобы попробовать команды бесплатно.