Ini bukan produk Google yang didukung secara resmi.
Ini adalah kode sumber untuk makalah "Amos: An Adam-style Optimizer dengan pembusukan berat adaptif menuju skala yang berorientasi pada model".
Ini mengimplementasikan AMOS , pengoptimal yang kompatibel dengan pustaka Optax, dan Jestimator , perpustakaan ringan dengan antarmuka seperti tf.Estimator untuk mengelola pos pemeriksaan yang kompatibel dengan T5X untuk program pembelajaran mesin di JAX, yang kami gunakan untuk menjalankan eksperimen di kertas.
pip install jestimator
Ini akan menginstal pengoptimal AMOS yang diimplementasikan di Jestimator Lib.
Implementasi AMOS ini digunakan dengan JAX, perpustakaan komputasi numerik berkinerja tinggi dengan diferensiasi otomatis, untuk penelitian pembelajaran mesin. API AMOS kompatibel dengan Optax, perpustakaan pengoptimal JAX (semoga AMOS akan diintegrasikan ke dalam Optax dalam waktu dekat).
Untuk menunjukkan penggunaan, kami akan menerapkan AMOS ke MNIST. Ini didasarkan pada contoh Mnist resmi Flax, dan Anda dapat menemukan kode di buku catatan Jupyter di sini.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
Objek TrainState menjaga parameter model dan status pengoptimal, dan dapat dicentang menjadi file.
Kami membuat model dan pengoptimal dalam fungsi ini.
Untuk pengoptimal, kami menggunakan AMOS di sini. Hyper-parameter berikut diatur:
Tingkat pembelajaran global biasanya diatur ke 1/sqrt (n), di mana n adalah jumlah batch dalam data pelatihan. Untuk MNIST, kami memiliki contoh pelatihan 60K dan ukuran batch adalah 32. Jadi learning_rate = 1/sqrt (60000/32).
'ETA_FN' khusus model membutuhkan fungsi yang, diberi nama dan bentuk variabel, mengembalikan float yang menunjukkan skala yang diharapkan dari variabel itu. Semoga dalam waktu dekat kita akan memiliki perpustakaan yang dapat secara otomatis menghitung 'ETA_FN' ini dari kode pemodelan; Tetapi untuk saat ini kita harus menentukannya secara manual.
Seseorang dapat menggunakan fungsi helper amos_helper.params_fn_from_assign_map () untuk membuat 'eta_fn' dari fecepat_map. Assign_map adalah dikt yang memetakan aturan Regex ke nilai atau ekspresi Python sederhana. Ini akan menemukan aturan regex pertama yang cocok dengan nama variabel, dan mengevaluasi ekspresi Python jika perlu untuk mengembalikan nilai. Lihat contoh kami di bawah ini.
'Shape_fn' juga membutuhkan fungsi yang, diberi nama dan bentuk variabel, mengembalikan bentuk yang dikurangi untuk variabel slot yang sesuai. Kita dapat menggunakan fungsi helper amos_helper.params_fn_from_assign_map () untuk membuat 'shape_fn' dari feceksion_map juga.
'Beta' adalah tingkat peluruhan eksponensial untuk menjalankan rata -rata kotak gradien. Kami mengaturnya ke 0,98 di sini.
'Clip_Value' adalah nilai kliping gradien, yang seharusnya sesuai dengan besarnya fungsi kerugian. Jika fungsi kerugian adalah jumlah cross-entropy, maka kita harus mengatur 'clip_value' ke SQRT dari jumlah label.
Silakan merujuk ke makalah kami untuk detail lebih lanjut tentang hyper-parameter.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
Gunakan dekorator @jit Jax untuk mengompilasi fungsi untuk kinerja yang lebih baik.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
Gunakan dekorator @jit Jax untuk mengompilasi fungsi untuk kinerja yang lebih baik.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
Jalankan loop pelatihan dan evaluasi pada set tes.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
Setelah 9 zaman, kita harus mendapatkan akurasi tes 99,26. Jika Anda berhasil, selamat!
Dengan Jestimator, Anda dapat membangun model Anda yang sebagian besar mirip dengan contoh MNIST di atas, tetapi tanpa menulis kode untuk bagian "utama"; Jestimator akan berfungsi sebagai titik masuk untuk model Anda, secara otomatis menangani pos pemeriksaan dalam mode kereta/eval-once/eval-while-dan-save-the-best/predict, dan mengatur profil, tenorboard, dan logging.
Selain itu, Jestimator mendukung partisi model yang diperlukan untuk melatih model yang sangat besar di beberapa pod TPU. Ini mendukung format pos pemeriksaan yang kompatibel dengan T5X yang menghemat dan mengembalikan pos pemeriksaan secara terdistribusi, yang cocok untuk model multi-pod besar.
Untuk menjalankan model dengan Jestimator, kita perlu menginstal T5X dan Flaxformer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
Kemudian, klon repo ini untuk mendapatkan kode Jestimator:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
Sekarang, kita dapat menguji model regresi linier mainan:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
Kami memberikan contoh Mnist ini untuk menunjukkan cara menulis kode pemodelan dengan Jestimator. Ini seperti contoh di atas, tetapi dengan keuntungan besar bahwa, objek konfigurasi dilewatkan untuk mengumpulkan informasi dari bendera global dan dataset, untuk mengatur pemodelan secara dinamis. Ini membuatnya lebih mudah untuk menerapkan model ke set data yang berbeda; Misalnya, seseorang dapat segera mencoba kumpulan data Emnist atau Eurosat hanya dengan mengubah argumen baris perintah, tanpa memodifikasi kode.
Dengan perintah berikut, kita dapat memulai pekerjaan untuk berlatih di MNIST, mencatat setiap 100 langkah, dan menyimpan pos pemeriksaan ke $ home/eksperimen/mnist/model:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
Sementara itu, kita dapat memulai pekerjaan untuk memantau $ home/eksperimen/folder Mnist/model, mengevaluasi pada set tes MNIST, dan menyimpan model dengan akurasi tertinggi:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
Pada saat yang sama, kita dapat memulai papan tensor untuk memantau prosesnya:
tensorboard --logdir $HOME/experiments/mnist/models
Kita dapat menggunakan perintah berikut untuk melatih satu lapisan LSTM di PTB:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
dan evaluasi:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
Sangat cocok untuk berjalan pada mesin gpu tunggal.
Berikut adalah beberapa panduan sederhana untuk pra-pelatihan dan menyempurnakan model seperti Bert, menggunakan TPU di Google Cloud Platform (GCP). Seseorang dapat mulai dengan browser web dengan pengaturan nol, dengan menghubungkan ke mesin virtual melalui Google Cloud Console, tanpa menginstal apa pun secara lokal. Jika ini adalah pertama kalinya, satu dicakup oleh kredit yang cukup untuk mencoba perintah secara gratis.