นี่ไม่ใช่ผลิตภัณฑ์ Google ที่ได้รับการสนับสนุนอย่างเป็นทางการ
นี่คือซอร์สโค้ดสำหรับกระดาษ "AMOS: เครื่องมือเพิ่มประสิทธิภาพสไตล์อดัมที่มีการสลายตัวของน้ำหนักแบบปรับได้สู่สเกลที่มุ่งเน้นแบบจำลอง"
มันใช้ AMOS ซึ่งเป็นเครื่องมือเพิ่มประสิทธิภาพที่เข้ากันได้กับไลบรารี Optax และ Jestimator ซึ่งเป็นไลบรารีน้ำหนักเบาที่มีอินเทอร์เฟซแบบ tf.Estimator -like เพื่อจัดการจุดตรวจสอบที่เข้ากันได้กับ T5X สำหรับโปรแกรมการเรียนรู้ของเครื่องใน JAX ซึ่งเราใช้เพื่อทำการทดลองในกระดาษ
pip install jestimator
มันจะติดตั้ง AMOS Optimizer ที่ใช้ใน Jestimator LIB
การใช้งานของ AMOS นี้ใช้กับ JAX ซึ่งเป็นไลบรารีการคำนวณเชิงตัวเลขที่มีประสิทธิภาพสูงพร้อมความแตกต่างอัตโนมัติสำหรับการวิจัยการเรียนรู้ของเครื่อง API ของ AMOS เข้ากันได้กับ Optax ซึ่งเป็นไลบรารีของ Jax Optimizers (หวังว่า AMOS จะรวมเข้ากับ Optax ในอนาคตอันใกล้)
เพื่อแสดงให้เห็นถึงการใช้งานเราจะใช้ AMOS กับ MNIST มันขึ้นอยู่กับตัวอย่าง MNIST อย่างเป็นทางการของ Flax และคุณสามารถค้นหารหัสในสมุดบันทึก Jupyter ได้ที่นี่
import jax
import jax.numpy as jnp # JAX NumPy
from jestimator import amos # The Amos optimizer implementation
from jestimator import amos_helper # Helper module for Amos
from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
import math
import tensorflow_datasets as tfds # TFDS for MNIST
from sklearn.metrics import accuracy_score
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def classify_xe_loss(self, x, labels):
# Labels read from the tfds MNIST are integers from 0 to 9.
# Logits are arrays of size 10.
logits = self(x)
logits = jax.nn.log_softmax(logits)
labels_ = jnp.expand_dims(labels, -1)
llh_ = jnp.take_along_axis(logits, labels_, axis=-1)
loss = -jnp.sum(llh_)
return loss
วัตถุ TrainState จะช่วยให้พารามิเตอร์แบบจำลองและสถานะเพิ่มประสิทธิภาพและสามารถตรวจสอบเป็นไฟล์ได้
เราสร้างโมเดลและเครื่องมือเพิ่มประสิทธิภาพในฟังก์ชั่นนี้
สำหรับเครื่องมือเพิ่มประสิทธิภาพเราใช้ AMO ที่นี่ มีการตั้งค่าพารามิเตอร์ไฮเปอร์ต่อไปนี้:
อัตราการเรียนรู้ทั่วโลกมักจะถูกตั้งค่าเป็น 1/sqrt (n) โดยที่ n คือจำนวนแบทช์ในข้อมูลการฝึกอบรม สำหรับ MNIST เรามีตัวอย่างการฝึกอบรม 60K และขนาดแบทช์คือ 32 ดังนั้นการเรียนรู้ _rate = 1/sqrt (60000/32)
โมเดลเฉพาะ 'ETA_FN' ต้องการฟังก์ชั่นที่ได้รับชื่อและรูปร่างตัวแปรส่งคืนโฟลตที่ระบุระดับที่คาดหวังของตัวแปรนั้น หวังว่าในอนาคตอันใกล้นี้เราจะมีห้องสมุดที่สามารถคำนวณ 'ETA_FN' ได้โดยอัตโนมัติจากรหัสการสร้างแบบจำลอง แต่ตอนนี้เราต้องระบุด้วยตนเอง
หนึ่งสามารถใช้ AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () ฟังก์ชั่นตัวช่วยเพื่อสร้าง 'ETA_FN' จาก ASTEMENT_MAP Assign_map เป็น dict ที่แมปกฎ regex กับค่าหรือนิพจน์ Python อย่างง่าย มันจะพบกฎ regex แรกที่ตรงกับชื่อของตัวแปรและประเมินนิพจน์ Python หากจำเป็นต้องส่งคืนค่า ดูตัวอย่างของเราด้านล่าง
'shape_fn' ในทำนองเดียวกันต้องการฟังก์ชั่นที่ได้รับชื่อและรูปร่างตัวแปรส่งคืนรูปร่างที่ลดลงสำหรับตัวแปรสล็อตที่เกี่ยวข้อง เราสามารถใช้ AMOS_HELPER.PARAMS_FN_FROM_ASSIGN_MAP () ฟังก์ชั่นตัวช่วยเพื่อสร้าง 'shape_fn' จาก action_map เช่นกัน
'เบต้า' เป็นอัตราการสลายตัวแบบเอ็กซ์โปเนนเชียลสำหรับการทำงานเฉลี่ยของการไล่ระดับสีสี่เหลี่ยม เราตั้งค่าเป็น 0.98 ที่นี่
'CLIP_VALUE' คือค่าการตัดไล่ระดับสีซึ่งควรตรงกับขนาดของฟังก์ชันการสูญเสีย หากฟังก์ชั่นการสูญเสียเป็นผลรวมของข้าม-เอนโทรปีเราควรตั้งค่า 'clip_value' เป็น SQRT ของจำนวนฉลาก
โปรดดูเอกสารของเราสำหรับรายละเอียดเพิ่มเติมเกี่ยวกับพารามิเตอร์ไฮเปอร์
def get_train_state(rng):
model = CNN()
dummy_x = jnp.ones([1, 28, 28, 1])
params = model.init(rng, dummy_x)
eta_fn = amos_helper.params_fn_from_assign_map(
{
'.*/bias': 0.5,
'.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
'.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
'.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
'.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
},
eval_str_value=True,
)
shape_fn = amos_helper.params_fn_from_assign_map(
{
'.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
'.*Dense_0/kernel': '(1, SHAPE[1])',
'.*': (),
},
eval_str_value=True,
)
optimizer = amos.amos(
learning_rate=1/math.sqrt(60000/32),
eta_fn=eta_fn,
shape_fn=shape_fn,
beta=0.98,
clip_value=math.sqrt(32),
)
return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=optimizer)
ใช้ @Jit Doraction ของ Jax เพื่อรวบรวมฟังก์ชั่นเพื่อประสิทธิภาพที่ดีขึ้น
@jax.jit
def train_step(batch, state):
grad_fn = jax.grad(state.apply_fn)
grads = grad_fn(
state.params,
batch['image'],
batch['label'],
method=CNN.classify_xe_loss)
return state.apply_gradients(grads=grads)
ใช้ @Jit Doraction ของ Jax เพื่อรวบรวมฟังก์ชั่นเพื่อประสิทธิภาพที่ดีขึ้น
@jax.jit
def infer_step(batch, state):
logits = state.apply_fn(state.params, batch['image'])
return jnp.argmax(logits, -1)
เรียกใช้ลูปการฝึกอบรมและประเมินผลในชุดทดสอบ
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng # Must not be used anymore.
num_epochs = 9
for epoch in range(1, num_epochs + 1):
# Use a separate PRNG key to permute image data during shuffling
rng, input_rng = jax.random.split(rng)
perms = jax.random.permutation(input_rng, 60000)
del input_rng
perms = perms.reshape((60000 // 32, 32))
for perm in perms:
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state = train_step(batch, state)
pred = jax.device_get(infer_step(test_ds, state))
accuracy = accuracy_score(test_ds['label'], pred)
print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
หลังจาก 9 ยุคเราควรได้รับความแม่นยำในการทดสอบ 99.26 ถ้าคุณทำมันขอแสดงความยินดี!
ด้วย jestimator คุณสามารถสร้างโมเดลของคุณส่วนใหญ่คล้ายกับตัวอย่าง MNIST ด้านบน แต่ไม่มีการเขียนรหัสสำหรับส่วน "หลัก"; Jestimator จะทำหน้าที่เป็นจุดเริ่มต้นสำหรับโมเดลของคุณจัดการกับจุดตรวจสอบโดยอัตโนมัติในโหมดรถไฟ/การประเมินผล/การประเมินผลในขณะที่การฝึกอบรมและการทำนายที่ดีที่สุด/ทำนายและตั้งค่าโปรไฟล์เทนซอร์บอร์ดและการบันทึก
นอกจากนี้ Jestimator ยังรองรับการแบ่งพาร์ติชันแบบจำลองซึ่งจำเป็นสำหรับการฝึกอบรมรุ่นที่มีขนาดใหญ่มากในหลาย ๆ TPU Pods รองรับรูปแบบจุดตรวจสอบที่เข้ากันได้กับ T5X ที่บันทึกและคืนค่าจุดตรวจในลักษณะกระจายซึ่งเหมาะสำหรับรุ่นหลายแบบหลายแบบ
ในการเรียกใช้โมเดลด้วย jestimator เราจำเป็นต้องติดตั้ง T5X และ Flaxformer:
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..
git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
จากนั้นโคลน repo นี้เพื่อรับรหัส jestimator:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
ตอนนี้เราสามารถทดสอบแบบจำลองการถดถอยเชิงเส้นของของเล่น:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
เราให้ตัวอย่าง MNIST นี้เพื่อสาธิตวิธีการเขียนรหัสการสร้างแบบจำลองด้วย jestimator มันเป็นเหมือนตัวอย่างข้างต้น แต่มีข้อได้เปรียบอย่างมากที่วัตถุกำหนดค่าจะถูกส่งผ่านเพื่อรวบรวมข้อมูลจากธงทั่วโลกและชุดข้อมูลเพื่อการตั้งค่าการตั้งค่าแบบไดนามิก สิ่งนี้ทำให้ง่ายต่อการใช้โมเดลกับชุดข้อมูลที่แตกต่างกัน ตัวอย่างเช่นหนึ่งสามารถลองชุดข้อมูล Emnist หรือ Eurosat ได้ทันทีโดยการเปลี่ยนอาร์กิวเมนต์บรรทัดคำสั่งโดยไม่ต้องแก้ไขรหัส
ด้วยคำสั่งต่อไปนี้เราสามารถเริ่มงานเพื่อฝึกอบรม MNIST บันทึกทุก 100 ขั้นตอนและบันทึกจุดตรวจไว้ที่ $ home/Experiments/MNIST/รุ่น:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--train_pattern="tfds://mnist/split=train"
--model_dir="$HOME/experiments/mnist/models"
--train_batch_size=32
--train_shuffle_buf=4096
--train_epochs=9
--check_every_steps=100
--max_ckpt=20
--save_every_steps=1000
--module_config.warmup=2000
--module_config.amos_beta=0.98
ในขณะเดียวกันเราสามารถเริ่มงานเพื่อตรวจสอบโฟลเดอร์ $ Home/Experiments/MNIST/Models, ประเมินในชุดทดสอบ MNIST และบันทึกโมเดลด้วยความแม่นยำสูงสุด:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.mnist.mnist"
--module_config="jestimator/models/mnist/mnist.py"
--eval_pattern="tfds://mnist/split=test"
--model_dir="$HOME/experiments/mnist/models"
--eval_batch_size=32
--mode="eval_wait"
--check_ckpt_every_secs=1
--save_high="test_accuracy"
ในเวลาเดียวกันเราสามารถเริ่มต้นบอร์ดเพื่อตรวจสอบกระบวนการ:
tensorboard --logdir $HOME/experiments/mnist/models
เราสามารถใช้คำสั่งต่อไปนี้เพื่อฝึกเลเยอร์ LSTM เดียวบน PTB:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt"
--model_dir="$HOME/models/ptb_lstm"
--train_batch_size=64
--train_consecutive=113
--train_shuffle_buf=4096
--max_train_steps=200000
--check_every_steps=1000
--max_ckpt=20
--module_config.opt_config.optimizer="amos"
--module_config.opt_config.learning_rate=0.01
--module_config.opt_config.beta=0.98
--module_config.opt_config.momentum=0.0
และประเมิน:
PYTHONPATH=. python3 jestimator/estimator.py
--module_imp="jestimator.models.lstm.lm"
--module_config="jestimator/models/lstm/lm.py"
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt"
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt"
--model_dir="$HOME/models/ptb_lstm"
--eval_batch_size=1
เหมาะสำหรับการทำงานบนเครื่อง Single-GPU
ต่อไปนี้เป็นคำแนะนำง่ายๆสำหรับรุ่นก่อนรถไฟและปรับแต่งรูปแบบ Bert-like โดยใช้ TPUs บน Google Cloud Platform (GCP) หนึ่งสามารถเริ่มต้นด้วยเว็บเบราว์เซอร์ที่มีการตั้งค่าเป็นศูนย์โดยเชื่อมต่อกับเครื่องเสมือนผ่าน Google Cloud Console โดยไม่ต้องติดตั้งอะไรในเครื่อง หากนี่เป็นครั้งแรกที่หนึ่งจะถูกครอบคลุมโดยเครดิตเพียงพอที่จะลองใช้คำสั่งฟรี