هذا ليس منتج Google المدعوم رسميًا.
هذا هو الكود المصدري للورقة "Amos: محسن على طراز آدم مع تسوس الوزن التكيفي نحو النطاق الموجهة نحو النموذج".
إنه ينفذ AMOS ، وهو مُحسّن متوافق مع مكتبة Optax ، و Jestimator ، وهي مكتبة خفيفة الوزن مع واجهة tf.Estimator مثل إدارة نقاط التفتيش المتوافقة مع T5X لبرامج التعلم الآلي في JAX ، والتي نستخدمها لتشغيل التجارب في الورقة.
pip install jestimator
سيقوم بتثبيت AMOS Optimizer المنفذ في Jestimator LIB.
يتم استخدام تطبيق AMOS هذا مع JAX ، وهي مكتبة حوسبة رقمية عالية الأداء مع تمايز تلقائي ، لأبحاث التعلم الآلي. API من AMOS متوافق مع Optax ، مكتبة من محسّلات Jax (نأمل أن يتم دمج AMOs في Optax في المستقبل القريب).
من أجل إظهار الاستخدام ، سنقوم بتطبيق Amos على Mnist. يعتمد على مثال MNIST الرسمي في Flax ، ويمكنك العثور على الكود في دفتر Jupyter هنا.
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
يحتفظ كائن TrainState بمعلمات النموذج وحالات المُحسّنة ، ويمكن أن يتم تفتيشه في الملفات.
نقوم بإنشاء النموذج والمحسّن في هذه الوظيفة.
للمحسن ، نستخدم Amos هنا. يتم تعيين المعلمة المفرطة التالية:
عادةً ما يتم ضبط معدل التعلم العالمي على 1/SQRT (N) ، حيث N هو عدد الدُفعات في بيانات التدريب. بالنسبة إلى MNIST ، لدينا 60K أمثلة تدريب وحجم الدُفعة هو 32. لذلك Learning_rate = 1/SQRT (60000/32).
يتطلب "ETA_FN" الخاص بالنموذج وظيفة ، بالنظر إلى اسم وشكل متغير ، بإرجاع تعويم يشير إلى المقياس المتوقع لهذا المتغير. نأمل في المستقبل القريب أن يكون لدينا مكتبات يمكنها حساب "eta_fn" تلقائيًا من رمز النمذجة ؛ ولكن في الوقت الحالي ، علينا تحديده يدويًا.
يمكن للمرء استخدام وظيفة المساعد AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () لإنشاء "eta_fn" من issis_map. envision_map عبارة عنيل يعين قواعد regex إلى قيمة أو تعبير بيثون بسيط. سيجد قاعدة Regex الأولى التي تطابق اسم المتغير ، وتقييم تعبير Python إذا لزم الأمر لإرجاع القيمة. انظر مثالنا أدناه.
يتطلب "شكل _fn" بمثابة وظيفة ، بالنظر إلى اسم وشكل متغير ، بإرجاع شكل مخفض لمتغيرات الفتحة المقابلة. يمكننا استخدام وظيفة المساعد AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () لإنشاء "شكل_fn" من issis_map أيضًا.
"بيتا" هو معدل الانحلال الأسي لمتوسط مربعات التدرج. قمنا بتعيينه على 0.98 هنا.
"Clip_value" هي قيمة القطع التدرج ، والتي يجب أن تتطابق مع حجم وظيفة الخسارة. إذا كانت وظيفة الخسارة عبارة عن مجموع من الإدخال المتقاطع ، فيجب علينا تعيين "clip_value" على SQRT لعدد الملصقات.
يرجى الرجوع إلى ورقتنا لمزيد من التفاصيل عن المعلميات المفرطة.
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
استخدم Jax's @JIT Decorator لتجميع الوظيفة في الوقت المناسب لأداء أفضل.
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
استخدم Jax's @JIT Decorator لتجميع الوظيفة في الوقت المناسب لأداء أفضل.
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
قم بتشغيل حلقة التدريب وتقييم مجموعة الاختبار.
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
بعد 9 عصر ، يجب أن نحصل على 99.26 دقة الاختبار. إذا قمت بذلك ، تهانينا!
مع Jestimator ، يمكنك إنشاء نموذجك في الغالب على مثال Mnist أعلاه ، ولكن دون كتابة رمز للقسم "الرئيسي" ؛ ستعمل Jestimator كنقطة دخول لنموذجك ، معالجة تلقائيًا في تحديد مقطوع القطار/التقييم/التقييم أثناء التدريب والأفضل/التنبؤ ، وإعداد التنميط والتنسيق والتسجيل.
بالإضافة إلى ذلك ، يدعم Jestimator تقسيم النماذج المطلوب لتدريب نماذج كبيرة جدًا عبر قرون TPU متعددة. وهو يدعم تنسيق نقطة تفتيش متوافق مع T5X يحفظ ويستعيد نقاط التفتيش بطريقة موزعة ، وهو مناسب لنماذج كبيرة متعددة البودات.
من أجل تشغيل النماذج باستخدام Jestimator ، نحتاج إلى تثبيت T5x و Flaxformer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
ثم ، استنساخ هذا الريبو للحصول على رمز Jestimator:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
الآن ، يمكننا اختبار نموذج الانحدار الخطي لعبة:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
نحن نقدم هذا المثال mnist لإظهار كيفية كتابة رمز النمذجة مع Jestimator. يشبه إلى حد كبير المثال أعلاه ، ولكن بميزة كبيرة ، يتم تمرير كائن التكوين لجمع المعلومات من الأعلام العالمية ومجموعة البيانات ، من أجل إعداد النمذجة ديناميكيًا. هذا يجعل من السهل تطبيق النموذج على مجموعات البيانات المختلفة ؛ على سبيل المثال ، يمكن للمرء على الفور تجربة مجموعات بيانات EMNIST أو EUROSAT ببساطة عن طريق تغيير وسيطة سطر الأوامر ، دون تعديل الرمز.
من خلال الأمر التالي ، يمكننا بدء وظيفة للتدريب على MNIST ، وتسجيل كل 100 خطوة ، وحفظ نقاط التفتيش إلى $ Home/Depariments/MNIST/Models:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
وفي الوقت نفسه ، يمكننا بدء وظيفة لمراقبة المجلد $ Home/Prepyents/MNIST/Models ، وتقييم مجموعة اختبار MNIST ، وحفظ النموذج بأعلى دقة:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
في الوقت نفسه ، يمكننا أن نبدأ من الترسور لمراقبة العملية:
tensorboard --logdir $HOME/experiments/mnist/models
يمكننا استخدام الأمر التالي لتدريب طبقة واحدة LSTM على PTB:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
وتقييم:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
إنه مناسب للتشغيل على آلة GPU الفردية.
فيما يلي بعض الأدلة البسيطة لنماذج ما قبل التدريب والبيتر ، باستخدام TPUs على منصة Google Cloud (GCP). يمكن للمرء أن يبدأ بمستعرض ويب مع إعداد صفري ، من خلال الاتصال بجهاز افتراضي عبر وحدة التحكم السحابية من Google ، دون تثبيت أي شيء محليًا. إذا كانت هذه هي المرة الأولى ، فسيتم تغطية واحدة مع ائتمانات كافية لتجربة الأوامر المجانية.