這不是官方支持的Google產品。
這是“ AMO:AMOS:ADAM-Style優化器具有自適應重量衰減對模型量表”的源代碼。
它實現了與Optax庫兼容的AMOS和帶有tf.Estimator -like接口的輕量級庫Jestimator ,用於管理JAX中機器學習程序的T5X兼容檢查點,我們用來在論文中運行實驗。
pip install jestimator
它將安裝Jestimator LIB中實現的AMOS優化器。
AMOS的實現與JAX一起使用,Jax是一個具有自動分化的高性能數值計算庫,用於機器學習研究。 AMOS的API與JAX優化器庫Optax兼容(希望在不久的將來將AMOS集成到Optax中)。
為了證明使用情況,我們將向MNIST應用AMO。它基於亞麻的官方MNIST示例,您可以在此處的Jupyter筆記本中找到代碼。
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
TrainState對象保留模型參數和優化器狀態,並且可以檢查到文件中。
我們在此功能中創建模型和優化器。
對於優化器,我們在此處使用AMO。設置了以下超參數:
全球學習率通常設置為1/sqrt(n),其中n是培訓數據中的批次數。對於MNIST,我們有60k的培訓示例,批次大小為32。因此,Learning_rate = 1/sqrt(60000/32)。
特定於模型的“ eta_fn”需要一個函數,鑑於變量名稱和形狀,返回浮點,指示該變量的預期尺度。希望在不久的將來,我們將擁有可以自動從建模代碼中計算出此“ ETA_FN”的庫;但是現在我們必須手動指定它。
一個人可以使用amos_helper.params_fn_from_assign_map()助手函數從agiss_map創建'eta_fn'。 agiss_map是一個dict,將其映射到一個值或簡單的python表達式。它將找到與變量名稱相匹配的第一個正則規則,並在必要時評估python表達式以返回值。請參閱下面的示例。
同樣,“ shape_fn”需要一個函數,在給定變量名稱和形狀的情況下,返回相應的插槽變量的降低形狀。我們可以使用Amos_helper.params_fn_from_assign_map()輔助功能來從agiss_map創建'shape_fn'。
“ beta”是梯度正方形平均運行平均值的指數衰減率。我們將其設置為0.98。
'clip_value'是梯度剪輯值,應與損耗函數的幅度匹配。如果損耗函數是交叉凝集的總和,則應將“ clip_value”設置為標籤數量的SQRT。
請參閱我們的論文以獲取更多參數的詳細信息。
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
使用JAX的@Jit Decorator來及時編譯該功能以提高性能。
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
使用JAX的@Jit Decorator來及時編譯該功能以提高性能。
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
運行訓練環並在測試集上進行評估。
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
在9個時期之後,我們應該獲得99.26的測試準確性。如果您做到了,恭喜!
使用Jestimator,您可以構建與上面的MNIST示例相似的模型,但沒有為“主”部分編寫代碼; Jestimator將作為您的模型的入口點,自動在火車/評估/評估訓練和訓練中最佳/預測模式中自動處理檢查點,並設置分析,張緊板和日誌記錄。
此外,Jestimator還支持模型分區,這是跨多個TPU POD訓練非常大型模型所必需的。它支持與T5X兼容的檢查點格式,該格式以分佈式方式保存和恢復檢查點,該格式適用於大型多POD模型。
為了使用Jestimator運行模型,我們需要安裝T5X和FlaxFormer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
然後,克隆此存儲庫以獲取Jestimator代碼:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
現在,我們可以測試玩具線性回歸模型:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
我們提供此MNIST示例,以演示如何使用Jestimator編寫建模代碼。這與上面的示例非常相似,但是具有很大的優勢,即傳遞配置對像以從全局標誌和數據集收集信息,以動態設置建模。這使得將模型應用於不同數據集變得更加容易;例如,只需更改命令行參數而無需修改代碼,就可以立即嘗試通過更改命令行參數來嘗試EMNIST或EUROSAT數據集。
有了以下命令,我們可以啟動一項工作以訓練MNIST,每100個步驟記錄每100個步驟,然後將檢查點保存到$ HOME/實驗/MNIST/模型:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
同時,我們可以啟動一項工作以監視$ HOME/實驗/MNIST/MODELS文件夾,對MNIST測試集進行評估,並以最高的精度保存模型:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
同時,我們可以啟動一個張板來監視該過程:
tensorboard --logdir $HOME/experiments/mnist/models
我們可以使用以下命令在PTB上訓練單層LSTM:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
並評估:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
它適合在單GPU機器上運行。
以下是使用Google Cloud Platform上的TPU(GCP)上的TPU預先培訓和微調BERT樣模型的一些簡單指南。一個人可以從具有零設置的Web瀏覽器開始,通過Google Cloud Console連接到虛擬機,而無需在本地安裝任何內容。如果這是第一次,則有足夠的信用來免費嘗試命令。