
ภาพรวม ติดตั้งอย่างรวดเร็ว | ผ้าลินินมีลักษณะอย่างไร? - เอกสาร
Flax NNX เปิดตัวในปี 2567 เป็น Flax API แบบง่ายใหม่ที่ออกแบบมาเพื่อให้ง่ายต่อการสร้างตรวจสอบการดีบักและวิเคราะห์เครือข่ายประสาทใน Jax มันประสบความสำเร็จโดยการเพิ่มการสนับสนุนชั้นหนึ่งสำหรับความหมายอ้างอิง Python สิ่งนี้ช่วยให้ผู้ใช้สามารถแสดงโมเดลของพวกเขาโดยใช้วัตถุ Python ปกติช่วยให้การแบ่งปันอ้างอิงและความไม่แน่นอน
Flax NNX วิวัฒนาการมาจาก Flax Linen API ซึ่งเปิดตัวในปี 2020 โดยวิศวกรและนักวิจัยที่ Google Brain ในการร่วมมืออย่างใกล้ชิดกับทีม JAX
คุณสามารถเรียนรู้เพิ่มเติมเกี่ยวกับ Flax NNX บนไซต์เอกสาร Flax เฉพาะ ตรวจสอบให้แน่ใจว่าคุณได้ตรวจสอบ:
หมายเหตุ: เอกสารของผ้าลินินลินินมีเว็บไซต์ของตัวเอง
ภารกิจของทีม Flax คือการให้บริการระบบนิเวศการวิจัยเครือข่าย Jax Neural ที่กำลังเติบโต - ทั้งในตัวอักษรและกับชุมชนที่กว้างขึ้นและเพื่อสำรวจกรณีการใช้งานที่ Jax ส่องแสง เราใช้ GitHub สำหรับการประสานงานและการวางแผนเกือบทั้งหมดของเรารวมถึงสถานที่ที่เราพูดถึงการเปลี่ยนแปลงการออกแบบที่กำลังจะมาถึง เรายินดีรับข้อเสนอแนะเกี่ยวกับการอภิปรายปัญหาและดึงเธรดคำขอ
คุณสามารถทำการร้องขอคุณสมบัติแจ้งให้เราทราบว่าคุณกำลังทำอะไรรายงานปัญหาถามคำถามในฟอรัมการสนทนา Flax GitHub ของเรา
เราคาดว่าจะปรับปรุงผ้าลินิน แต่เราไม่คาดหวังว่าจะมีการเปลี่ยนแปลงที่สำคัญใน Core API เราใช้รายการ Changelog และคำเตือนการเสื่อมราคาเมื่อเป็นไปได้
ในกรณีที่คุณต้องการติดต่อเราโดยตรงเราอยู่ที่ [email protected]
Flax เป็นห้องสมุดเครือข่ายประสาทประสิทธิภาพสูงและระบบนิเวศสำหรับ JAX ที่ ออกแบบมาเพื่อความยืดหยุ่น : ลองรูปแบบใหม่ของการฝึกอบรมโดยการหาตัวอย่างและโดยการปรับเปลี่ยนลูปการฝึกอบรมไม่เพิ่มคุณสมบัติลงในกรอบ
ผ้าลินินกำลังได้รับการพัฒนาอย่างใกล้ชิดกับทีม JAX และมาพร้อมกับทุกสิ่งที่คุณต้องการเพื่อเริ่มการวิจัยของคุณรวมถึง:
Neural Network API ( flax.nnx ): รวมถึง Linear , Conv , BatchNorm , LayerNorm , GroupNorm , ความสนใจ ( MultiHeadAttention ), LSTMCell , GRUCell , Dropout
ยูทิลิตี้และรูปแบบ : การฝึกอบรมการทำซ้ำการทำให้เป็นอนุกรมและการตรวจสอบ, การวัด, การดึงข้อมูลล่วงหน้าบนอุปกรณ์
ตัวอย่างการศึกษา : MNIST, การอนุมาน/การสุ่มตัวอย่างด้วยรูปแบบภาษา Gemma (หม้อแปลง), Transformer LM1B
Flax ใช้ JAX ดังนั้นตรวจสอบคำแนะนำการติดตั้ง JAX บน CPU, GPU และ TPU
คุณจะต้องมี Python 3.8 หรือใหม่กว่า ติดตั้ง Flax จาก PYPI:
pip install flax
ในการอัพเกรดเป็น Flax เวอร์ชันล่าสุดคุณสามารถใช้:
pip install --upgrade git+https://github.com/google/flax.git
ในการติดตั้งการพึ่งพาเพิ่มเติมบางอย่าง (เช่น matplotlib ) ที่จำเป็น แต่ไม่รวมอยู่ในการอ้างอิงบางอย่างคุณสามารถใช้:
pip install " flax[all] " เราให้สามตัวอย่างโดยใช้ Flax API: Perceptron แบบหลายชั้นอย่างง่าย, CNN และเครื่องเข้ารหัสอัตโนมัติ
หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับ Module ที่เป็นนามธรรมให้ตรวจสอบเอกสารของเราคำนำวงกว้างของเราไปยังโมดูลนามธรรม สำหรับการสาธิตที่เป็นรูปธรรมเพิ่มเติมเกี่ยวกับแนวปฏิบัติที่ดีที่สุดโปรดดูคำแนะนำและบันทึกของนักพัฒนาของเรา
ตัวอย่างของ MLP:
class MLP ( nnx . Module ):
def __init__ ( self , din : int , dmid : int , dout : int , * , rngs : nnx . Rngs ):
self . linear1 = Linear ( din , dmid , rngs = rngs )
self . dropout = nnx . Dropout ( rate = 0.1 , rngs = rngs )
self . bn = nnx . BatchNorm ( dmid , rngs = rngs )
self . linear2 = Linear ( dmid , dout , rngs = rngs )
def __call__ ( self , x : jax . Array ):
x = nnx . gelu ( self . dropout ( self . bn ( self . linear1 ( x ))))
return self . linear2 ( x )ตัวอย่างของ CNN:
class CNN ( nnx . Module ):
def __init__ ( self , * , rngs : nnx . Rngs ):
self . conv1 = nnx . Conv ( 1 , 32 , kernel_size = ( 3 , 3 ), rngs = rngs )
self . conv2 = nnx . Conv ( 32 , 64 , kernel_size = ( 3 , 3 ), rngs = rngs )
self . avg_pool = partial ( nnx . avg_pool , window_shape = ( 2 , 2 ), strides = ( 2 , 2 ))
self . linear1 = nnx . Linear ( 3136 , 256 , rngs = rngs )
self . linear2 = nnx . Linear ( 256 , 10 , rngs = rngs )
def __call__ ( self , x ):
x = self . avg_pool ( nnx . relu ( self . conv1 ( x )))
x = self . avg_pool ( nnx . relu ( self . conv2 ( x )))
x = x . reshape ( x . shape [ 0 ], - 1 ) # flatten
x = nnx . relu ( self . linear1 ( x ))
x = self . linear2 ( x )
return xตัวอย่างของ autoencoder:
Encoder = lambda rngs : nnx . Linear ( 2 , 10 , rngs = rngs )
Decoder = lambda rngs : nnx . Linear ( 10 , 2 , rngs = rngs )
class AutoEncoder ( nnx . Module ):
def __init__ ( self , rngs ):
self . encoder = Encoder ( rngs )
self . decoder = Decoder ( rngs )
def __call__ ( self , x ) -> jax . Array :
return self . decoder ( self . encoder ( x ))
def encode ( self , x ) -> jax . Array :
return self . encoder ( x )เพื่ออ้างถึงที่เก็บนี้:
@software{flax2020github,
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.10.2},
year = {2024},
}
ในรายการ BibTex ด้านบนชื่ออยู่ในลำดับตัวอักษรหมายเลขเวอร์ชันมีวัตถุประสงค์เพื่อให้ได้จาก Flax/vertion.py และปีสอดคล้องกับการเปิดตัวโอเพนซอร์ซของโครงการ
Flax เป็นโครงการโอเพ่นซอร์สที่ดูแลโดยทีมงานเฉพาะที่ Google DeepMind แต่ไม่ใช่ผลิตภัณฑ์ของ Google อย่างเป็นทางการ