这不是官方支持的Google产品。
这是“ AMO:AMOS:ADAM-Style优化器具有自适应重量衰减对模型量表”的源代码。
它实现了与Optax库兼容的AMOS和带有tf.Estimator -like接口的轻量级库Jestimator ,用于管理JAX中机器学习程序的T5X兼容检查点,我们用来在论文中运行实验。
pip install jestimator
它将安装Jestimator LIB中实现的AMOS优化器。
AMOS的实现与JAX一起使用,Jax是一个具有自动分化的高性能数值计算库,用于机器学习研究。 AMOS的API与JAX优化器库Optax兼容(希望在不久的将来将AMOS集成到Optax中)。
为了证明使用情况,我们将向MNIST应用AMO。它基于亚麻的官方MNIST示例,您可以在此处的Jupyter笔记本中找到代码。
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
TrainState对象保留模型参数和优化器状态,并且可以检查到文件中。
我们在此功能中创建模型和优化器。
对于优化器,我们在此处使用AMO。设置了以下超参数:
全球学习率通常设置为1/sqrt(n),其中n是培训数据中的批次数。对于MNIST,我们有60k的培训示例,批次大小为32。因此,Learning_rate = 1/sqrt(60000/32)。
特定于模型的“ eta_fn”需要一个函数,鉴于变量名称和形状,返回浮点,指示该变量的预期尺度。希望在不久的将来,我们将拥有可以自动从建模代码中计算出此“ ETA_FN”的库;但是现在我们必须手动指定它。
一个人可以使用amos_helper.params_fn_from_assign_map()助手函数从agiss_map创建'eta_fn'。 agiss_map是一个dict,将其映射到一个值或简单的python表达式。它将找到与变量名称相匹配的第一个正则规则,并在必要时评估python表达式以返回值。请参阅下面的示例。
同样,“ shape_fn”需要一个函数,在给定变量名称和形状的情况下,返回相应的插槽变量的降低形状。我们可以使用Amos_helper.params_fn_from_assign_map()辅助功能来从agiss_map创建'shape_fn'。
“ beta”是梯度正方形平均运行平均值的指数衰减率。我们将其设置为0.98。
'clip_value'是梯度剪辑值,应与损耗函数的幅度匹配。如果损耗函数是交叉凝集的总和,则应将“ clip_value”设置为标签数量的SQRT。
请参阅我们的论文以获取更多参数的详细信息。
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
使用JAX的@Jit Decorator来及时编译该功能以提高性能。
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
使用JAX的@Jit Decorator来及时编译该功能以提高性能。
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
运行训练环并在测试集上进行评估。
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
在9个时期之后,我们应该获得99.26的测试准确性。如果您做到了,恭喜!
使用Jestimator,您可以构建与上面的MNIST示例相似的模型,但没有为“主”部分编写代码; Jestimator将作为您的模型的入口点,自动在火车/评估/评估训练和训练中最佳/预测模式中自动处理检查点,并设置分析,张紧板和日志记录。
此外,Jestimator还支持模型分区,这是跨多个TPU POD训练非常大型模型所必需的。它支持与T5X兼容的检查点格式,该格式以分布式方式保存和恢复检查点,该格式适用于大型多POD模型。
为了使用Jestimator运行模型,我们需要安装T5X和FlaxFormer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
然后,克隆此存储库以获取Jestimator代码:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
现在,我们可以测试玩具线性回归模型:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
我们提供此MNIST示例,以演示如何使用Jestimator编写建模代码。这与上面的示例非常相似,但是具有很大的优势,即传递配置对象以从全局标志和数据集收集信息,以动态设置建模。这使得将模型应用于不同数据集变得更加容易;例如,只需更改命令行参数而无需修改代码,就可以立即尝试通过更改命令行参数来尝试EMNIST或EUROSAT数据集。
有了以下命令,我们可以启动一项工作以训练MNIST,每100个步骤记录每100个步骤,然后将检查点保存到$ HOME/实验/MNIST/模型:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
同时,我们可以启动一项工作以监视$ HOME/实验/MNIST/MODELS文件夹,对MNIST测试集进行评估,并以最高的精度保存模型:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
同时,我们可以启动一个张板来监视该过程:
tensorboard --logdir $HOME/experiments/mnist/models
我们可以使用以下命令在PTB上训练单层LSTM:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
并评估:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
它适合在单GPU机器上运行。
以下是使用Google Cloud Platform上的TPU(GCP)上的TPU预先培训和微调BERT样模型的一些简单指南。一个人可以从具有零设置的Web浏览器开始,通过Google Cloud Console连接到虚拟机,而无需在本地安装任何内容。如果这是第一次,则有足够的信用来免费尝试命令。