これは、公式にサポートされているGoogle製品ではありません。
これは、「モデル指向のスケールに向けて適応体重減衰を備えたアダムスタイルのオプティマイザー」という論文のソースコードです。
Optaxライブラリと互換性のあるオプティマイザーであるAmosと、JAXの機械学習プログラムのT5X互換チェックポイントを管理するためのtf.Estimatorのようなインターフェイスを備えた軽量ライブラリであるJestimatorを実装します。
pip install jestimator
Jestimator libに実装されたAMOSオプティマイザーをインストールします。
AMOSのこの実装は、機械学習研究のために、自動分化を備えた高性能数値コンピューティングライブラリであるJaxで使用されます。 AMOSのAPIは、JAXオプティマイザーのライブラリであるOptaxと互換性があります(近い将来、AMOSがOptaxに統合されることを願っています)。
使用法を示すために、AMOSをMNISTに適用します。 Flaxの公式Mnistの例に基づいており、こちらのJupyterノートにコードを見つけることができます。
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
TrainStateオブジェクトは、モデルパラメーターとオプティマイザーの状態を保持し、ファイルにチェックポイントすることができます。
この関数でモデルとオプティマイザーを作成します。
オプティマイザーには、ここでAMOを使用します。次のハイパーパラメーターが設定されています。
グローバル学習率は通常、1/sqrt(n)に設定されます。ここで、nはトレーニングデータのバッチ数です。 MNISTの場合、60kのトレーニングの例とバッチサイズは32です。したがって、Learning_rate = 1/SQRT(60000/32)。
モデル固有の「ETA_FN」には、変数名と形状が与えられた場合、その変数の予想されるスケールを示すフロートを返す関数が必要です。近い将来、モデリングコードからこの「ETA_FN」を自動的に計算できるライブラリがあることを願っています。しかし今のところ、手動で指定する必要があります。
AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP()HELPER関数を使用して、assight_mapから「ETA_FN」を作成できます。 assight_mapは、regexルールを値または単純なpython式にマッピングするdictです。変数の名前と一致する最初の正規表現ルールが見つかり、必要に応じて値を返すためにPython式を評価します。以下の例を参照してください。
「Shape_fn」には、変数名と形状が与えられた場合、対応するスロット変数の形状が縮小される関数が必要です。 AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP()HELPER関数を使用して、Assight_Mapから「Shape_Fn」も作成できます。
「ベータ」とは、勾配正方形の平均を実行するための指数減衰率です。ここで0.98に設定します。
「Clip_Value」は勾配クリッピング値であり、損失関数の大きさと一致するはずです。損失関数がクロスエントロピーの合計である場合、ラベルの数のSQRTに「Clip_Value」を設定する必要があります。
ハイパーパラメーターの詳細については、私たちの論文を参照してください。
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Jaxの@jitデコレーターを使用して、Just-in-timeコンパイルをパフォーマンスの向上にします。
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Jaxの@jitデコレーターを使用して、Just-in-timeコンパイルをパフォーマンスの向上にします。
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
トレーニングループを実行し、テストセットで評価します。
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
9つの時代の後、99.26のテスト精度を取得する必要があります。あなたがそれを作ったなら、おめでとう!
Jestimatorを使用すると、モデルを上記のMnistの例とほぼ同様に構築できますが、「メイン」セクションのコードを記述することはできません。 Jestimatorは、モデルのエントリポイントとして機能し、トレイン/評価/評価/評価とベスト/予測モードでのチェックポイントを自動的に処理し、プロファイリング、テンソルボード、ロギングをセットアップします。
さらに、Jestimatorは、複数のTPUポッドにわたって非常に大きなモデルをトレーニングするために必要なモデルパーティション化をサポートしています。大規模なマルチポッドモデルに適した分散方法でチェックポイントを保存および復元するT5X互換のチェックポイント形式をサポートします。
Jestimatorでモデルを実行するには、T5XとFlaxformerをインストールする必要があります。
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
次に、このリポジトリをクローンして、ジェスティマーターコードを取得します。
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
これで、おもちゃの線形回帰モデルをテストできます。
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
このMnistの例を提供して、Jestimatorを使用してモデリングコードを作成する方法を示します。これは上記の例によく似ていますが、モデリングを動的にセットアップするために、グローバルフラグとデータセットから情報を収集するために構成オブジェクトが渡されるという大きな利点があります。これにより、モデルを異なるデータセットに簡単に適用できます。たとえば、コードを変更せずにコマンドライン引数を変更するだけで、すぐにエミストまたはユーロサットデータセットを試すことができます。
次のコマンドを使用すると、MNISTでトレーニングし、100ステップごとにログを記録し、チェックポイントを$ HOME/EXPERIMENTS/MNIST/モデルに保存するためのジョブを開始できます。
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
一方、$ HOME/EXPERIMENTS/MNIST/MODELEフォルダーを監視するためのジョブを開始し、MNISTテストセットで評価し、最も高い精度でモデルを保存できます。
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
同時に、テンソルボードを起動してプロセスを監視できます。
tensorboard --logdir $HOME/experiments/mnist/models
次のコマンドを使用して、PTBで単一層LSTMをトレーニングできます。
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
評価:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
シングルGPUマシンでの実行に適しています。
Google Cloud Platform(GCP)でTPUを使用して、Bertのようなモデルを事前に微調整し、微調整するための簡単なガイドを以下に示します。 Google Cloud Consoleを介して仮想マシンに接続して、ローカルに何もインストールせずに、セットアップがゼロのWebブラウザーから始めることができます。これが初めての場合、コマンドを無料で試すのに十分なクレジットでカバーされます。