
注意力Grid是由Agora帶給您的,我們是一個全新的開源多模式AI研究組織,致力於推進人類。
加入我們這裡,為該項目做出貢獻或獲得支持!

注意力Grid是一個尖端的框架,旨在將高級註意機制的融合到AI模型中。注意力基於基於注意力的變壓器模型的最新發展,注意力Grid向機器學習從業人員,研究人員和愛好者開闢了關注機制的世界。
要用注意格爆炸,請使用PIP安裝包裝:
pip install AttentionGrid實施注意機製或帶有註意格的變壓器模型很容易:
from AttentionGrid import BlockwiseParallelJax
import jax . numpy as jnp
# Initialize the class
bpjax = BlockwiseParallelJax (
q_chunk_size = 64 ,
k_chunk_size = 64 ,
hidden_size = 768 ,
num_heads = 12 ,
rotary_dim = 32 ,
intermediate_size = 3072
)
# Suppose we have hidden_states, attention_mask, and position_ids as input data
hidden_states = jnp . random . rand ( 1 , 100 , 768 )
attention_mask = jnp . random . rand ( 1 , 1 , 100 , 100 )
position_ids = jnp . arange ( 100 ). reshape ( 1 , 100 )
# You can now apply the attention mechanism to your input data
output = bpjax . forward ( hidden_states , attention_mask , position_ids )我們鼓勵您與您的社區分享注意力!以下是幾個社交媒體平台的快速共享鏈接:
在Twitter上分享
分享LinkedIn
在Facebook上分享
分享Reddit
分享WhatsApp
感謝您支持注意力grig,並為人工智能的民主化做出貢獻!一起,我們可以突破可能的界限。
在AI的廣闊景觀中,注意機制徹底改變了我們創建強大模型的能力,可以辨別數據中的微妙之處,重點關注重要方面並改善整體性能。我們使用注意力網格的願景是彌合這些最新機制與其實際應用之間的差距,提供了一種工具,使這些技術在不同的AI應用中易於訪問且易於實現。
注意力Grid設計具有直觀且靈活的架構,分為四個主要組成部分:
核心:這是我們框架的基石,是集中抽像類,它們為註意機制和變壓器模型的基本結構佈局。
注意力? :專門針對各種注意機制的目錄。每個注意機制都是根據核心中提供的藍圖實現的。
變壓器? :這是變壓器模型的生命,每個模型都按照核心中定義的設計雕刻。
UTILS :一個裝有助手類的工具箱,用於基本任務,例如模型加載,數據預處理等。
示例:通過動手實例和用法方案揭開實施的神秘面紗。
模塊化結構:將不同的注意機制與多種變壓器模型匹配。
用戶友好:清晰的文檔和示例,以幫助您快速入門。
開源:對貢獻開放,注意力網格在集體知識和共同的進步上蓬勃發展。
有關更詳細的示例,請參閱我們存儲庫中的“示例”文件夾。
我們公開邀請人們對注意力的貢獻!無論您有新功能建議,錯誤報告還是要添加到我們的代碼中,請隨時打開問題或提交拉動請求。
注意力Grid是由Apache許可證獲得許可的開源軟件。
注意機制已轉化為AI,使機器可以“專注”輸入數據的重要部分。通過注意力網絡,我們旨在民主化對這些強大工具的訪問。我們認為,人工智能的未來在於關注的力量,並且通過注意力,我們希望加速這一旅程。探索我們的存儲庫,加入我們的事業,讓我們一起瀏覽這一激動人心的景觀!
“細節不是細節。它們進行了設計。” - 查爾斯·埃姆斯(Charles Eames)
整合閃光的注意力和變體
整合具有里程碑意義的關注
整合塊平行注意
整合動態稀疏閃光注意力
整合來自ImageBind的交叉注意
整合柯爾特5的注意力
整合多奇的關注
整合來自清醒降雨的包裝紙X_TransFormers,解碼器,注意力,編碼器,變壓器包裝器
| 機制 | 調用方法 | 示例導入 |
|---|---|---|
| 自我注意力 | from AttentionGrid import SelfAttention | from AttentionGrid import SelfAttention |
| 全球關注 | from AttentionGrid import GlobalAttention | from AttentionGrid import GlobalAttention |
| 當地的關注 | from AttentionGrid import LocalAttention | from AttentionGrid import LocalAttention |
| 分層的關注 | from AttentionGrid import HierarchicalAttention | from AttentionGrid import HierarchicalAttention |
| 動態稀疏的注意力 | from AttentionGrid import dynamic_sparse_attention | from AttentionGrid import dynamic_sparse_attention |
| 緊湊功能 | from AttentionGrid import compact | from AttentionGrid import compact |
| 墊索引功能 | from AttentionGrid import pad_index | from AttentionGrid import pad_index |
| 塊平行的注意 | from AttentionGrid import BlockwiseParallelJax | from AttentionGrid import BlockwiseParallelJax |
| 閃爍注意力 | from AttentionGrid import FlashAttention | from AttentionGrid import FlashAttention |
| 具有里程碑意義的關注 | from AttentionGrid import LandmarkAttention | from AttentionGrid import LandmarkAttention |
| 柯爾特-5注意 | from AttentionGrid import Colt5Attention | from AttentionGrid import Colt5Attention |
| 多傳奇的關注 | from AttentionGrid import MultiQueryAttention | from AttentionGrid import MultiQueryAttention |
| 注意力擴張 | from AttentionGrid import DilatedAttention | from AttentionGrid import DilatedAttention |
Agora的dynamic_sparse_attention功能允許在Hash-Sparse實現和QK-SPARSE實現之間進行選擇。該函數的目標是基於所選的sparsity_mode指導稀疏的注意機制。
功能參數如下:
q :查詢形狀的張量(批次,n_ctx_q,h,d_head)k :形狀的關鍵張量(批次,n_ctx_kv,h,d_head)v :形狀的值張量(批次,n_ctx_kv,h,d_head)q_idx & k_idx :如果Sparsity_mode為'Hash',則表示存儲率索引,或者如果Sparsity_mode為'qk',是否要保持給定頭。張量形狀分別為(批處理,n_ctx_q,h)和(批處理,n_ctx_kv,h)。sm_scale :歸一化常數,1/sqrt(d_head),除非指定。sparsity_mode :'HASH'用於為QK-SPARSE實現選擇Hash-Sparse實現和“ QK”。默認情況下, sm_scale的計算是默認情況下的,如果給出了未知的sparsity_mode ,則會拋出鍵盤。
然後,該函數檢查sparsity_mode並根據其值,它調用hash_sparse_attention或qk_sparse_attention 。
compact函數使用keep_tensor的信息構建了輸入張量x的緊湊表示。
功能參數為:
x :帶有形狀的輸入張量(批處理,n_ctx,h,d_head)。keep_tensor :float張量(批處理,n_ctx,h),該張量在保持頭部時包含1,否則為0。該函數首先計算indices_per_head ,該Indices_per_head計算每個頭部非殺菌元素的數量。它在保留均等元素的順序(stable = true)的同時,按降序對keep_tensor進行分類。然後,它根據索引張量收集x的元素。結果是x的緊湊表示以及索引張量和代表每個頭部非殺菌元件數量的張量。
pad_index函數將索引張量填充以符合內核。它採用以下參數:
index :由compact給出的原始索引張量,帶有形狀(批處理,buffer_size,h)。對於每個批次和時間步,它代表其源自的頭部索引。indices_per_head :對於每個頭部,包含尚未刪除多少個索引。它創建索引張量的副本,並根據indices_per_head的大小創建掩碼。然後,它修改了副本中與蒙版中對應的索引,以等於pad_idx 。
qk_sparse_attention功能是動態稀疏注意機制的一部分。當sparsity_mode設置為'qk'時,使用它。此功能實現了QK-SPARSE注意機制,並要求q_keep和k_keep參數為float類型。
它首先使用compact函數構建查詢,密鑰和值張量的緊湊表示。然後,它使用pad_index函數對索引張量進行填充。然後將張量轉移以與核的兼容性。最後,該功能調用qk_sparse_attention_kernel功能,並將結果張量分散回原始尺寸空間。
hash_sparse_attention函數是動態稀疏注意機制的一部分。當sparsity_mode設置為“哈希”時,使用它。該功能實現了哈希·帕斯斯注意機制。
該函數採用與qk_sparse_attention相同的輸入參數。但是, hash_sparse_attention函數而不是q_keep和k_keep參數,需要q_bucket_idx和k_bucket_idx ,該功能分別代表查詢和鍵的存儲桶索引。
hash_sparse_attention函數首先使用sort_bucketed_attention函數根據存儲鍵指數對查詢,鍵和值張量進行分類。然後,它使用compact函數構建了分類查詢,鍵和值張量的緊湊表示。然後,它使用pad_index函數對索引張量進行填充。
然後將張量轉移以與核的兼容性。然後,該功能調用hash_sparse_attention_kernel函數,並將結果張量分散回原始尺寸空間。
sort_bucketed_attention函數是hash_sparse_attention中使用的輔助功能。它根據給定的存儲桶指數對輸入張量進行分類。
功能參數為:
qkv :查詢,鍵,值張量(批次,n_ctx,h,d_head)qkv_bucket_idx :查詢,鍵和形狀值的存儲率索引(批次,n_ctx,h)該功能首先對qkv_bucket_idx張量進行分類,並獲取排序的索引。然後,它使用排序的索引對qkv張量進行分組。它還將qkv_bucket_idx擴展到與qkv相同的形狀,以兼容。
qk_sparse_attention_kernel函數是qk_sparse_attention中使用的內核函數。它根據查詢和關鍵產品的軟效果計算了價值的加權總和。
功能參數為:
q :查詢形狀的張量(批次,n_ctx_q,h,d_head)k :形狀的關鍵張量(批次,n_ctx_kv,h,d_head)v :形狀的值張量(批次,n_ctx_kv,h,d_head)sm_scale :歸一化常數,1/sqrt(d_head),除非指定。hash_sparse_attention_kernel函數是hash_sparse_attention中使用的內核函數。它的工作原理與qk_sparse_attention_kernel類似,但可以處理Hash-Sparse注意的桶。
功能參數與qk_sparse_attention_kernel的函數參數相同。但是, q , k和v已根據存儲桶指數進行了分類和壓實。
內核計算查詢和鍵的乘積,按sm_scale縮放它,應用SoftMax來獲得權重,然後計算值的加權總和。
請注意,這是對文檔的一般解釋,在實踐中理解和修改這些功能可能需要深入了解稀疏注意機制和深度學習原理。
blockwise_compute_attn函數:
blockwise_compute_attn函數是BlockwiseParallelJax類的重要組成部分,用於以模塊的方式計算模型的注意機制。
參數:
query , key , value :這些參數分別代表查詢,鍵和值的主要輸入。bias :可選參數,用於在軟磁性之前為註意力分數增加偏差。deterministic :用於決定是否應用輟學的布爾標誌。dropout_rng :用於輟學的隨機數生成器。attn_pdrop :輟學的可能性。causal_mask :是否使用因果關注面具,是布爾的標誌。query_chunk_size , key_chunk_size :每個查詢和密鑰塊的大小分別。dtype :計算的數據類型。默認值是jnp.float32 。policy :此參數定義了梯度檢查點的策略。precision :此參數用於設置計算的精度級別。默認值是lax.Precision.HIGHEST 。prevent_cse :用於防止常見子表達消除的布爾標誌。blockwise_compute_ffn函數:
blockwise_compute_ffn函數用於以模塊的方式計算模型的饋送網絡。
參數:
cell :應用函數的網絡中的單元格。inputs :饋送網絡的輸入數據。chunk_size :每個塊的大小用於塊計算。deterministic :用於決定是否應用輟學的布爾標誌。policy :此參數定義了梯度檢查點的策略。prevent_cse :用於防止常見子表達消除的布爾標誌。blockwise_lm_head類:
Blockwise_LM_Head類是應用線性變換的模塊,然後使用軟磁性函數來在輸入中每個位置的詞彙上產生分佈。
vocab_size :詞彙的大小,這也是線性轉換的輸出維度的大小。chunk_size :每個塊的大小用於塊計算。policy :此參數定義了梯度檢查點的策略。dtype :計算的數據類型。默認值是jnp.float32 。prevent_cse :用於防止常見子表達消除的布爾標誌。blockwise_cross_entropy函數:
blockwise_cross_entropy函數以模型方式計算模型預測的跨凝結損失。
參數:
logits :模型的輸出預測。tokens :真正的標籤。valid :指定輸入中有效位置的掩碼。chunk_size :每個塊的大小用於塊計算。policy :此參數定義了梯度檢查點的策略。prevent_cse :用於防止常見子表達消除的布爾標誌。BlockWiseParallelJax類:
BlockwiseParallelJax ( q_chunk_size , k_chunk_size , hidden_size , num_heads , rotary_dim , intermediate_size , layer_norm_epsilon = 1e-5 , activation_function = "gelu" , attn_pdrop = 0.0 , resid_pdrop = 0.0 , max_position_embeddings = 1024 , dtype = jnp . float32 , causal = True , policy = 'nothing_saveable' , prevent_cse = False , float32_logits = False )參數
q_chunk_size :整數。自我注意的查詢的塊大小。k_chunk_size :整數。自我注意的鑰匙的大小。hidden_size :整數。變壓器中隱藏層的維度。num_heads :整數。自我注意力機制中的注意力頭數。rotary_dim :整數或無。用於旋轉位置編碼的尺寸數量。intermediate_size :整數。饋送網絡中中間層的大小。layer_norm_epsilon :float。較小的常數以防止層歸一化零分裂。默認值為1e-5 。activation_function :字符串。激活函數用於進料向前網絡。默認值為'gelu' 。attn_pdrop :float。注意機制的輟學概率。默認值為0.0 。resid_pdrop :浮動。殘留連接的輟學概率。默認值為0.0 。max_position_embeddings :整數。最大使用位置嵌入的數量。默認值為1024 。dtype :jnp.dtype。用於計算的數據類型。默認值為jnp.float32 。causal :布爾。是否使用因果(自動回歸)模式。默認是True 。policy :字符串。檢查點梯度的策略。默認值是'nothing_saveable' 。prevent_cse :布爾值。是否預防常見的亞表達消除(CSE)。默認值為False 。float32_logits :布爾值。是否將Float32用於邏輯計算。默認值為False 。方法
BlockwiseParallelJax類的主要方法是forward方法,該方法執行變壓器塊的正向通行。
forward ( hidden_states , attention_mask , position_ids , deterministic = True , init_cache = False )hidden_states :jnp.ndarray。變壓器塊的輸入張量。它應該具有形狀(batch_size, sequence_length, hidden_size) 。attention_mask :jnp.ndarray。自我發項機制的注意力掩蓋。它應該具有形狀(batch_size, 1, 1, sequence_length) 。position_ids :jnp.ndarray。位置編碼的位置ID。它應該具有形狀(1, sequence_length) 。deterministic :布爾值。是否使用確定性模式(無輟學模式)。默認是True 。init_cache :布爾值。是否要初始化緩存以進行快速解碼。默認值為False 。此方法返回變壓器塊的輸出張量,該塊具有與hidden_states相同的形狀。
示例用法
以下示例演示瞭如何使用BlockwiseParallelJax類。
# Initialize
from jax import random
import jax . numpy as jnp
from AttentionGrid import BlockwiseParallelJax
# Initialize transformer block
block = BlockwiseParallelJax (
q_chunk_size = 64 ,
k_chunk_size = 64 ,
hidden_size = 768 ,
num_heads = 12 ,
rotary_dim = 64 ,
intermediate_size = 3072 ,
)
# Create a batch of input tensors
key = random . PRNGKey ( 0 )
batch_size = 8
sequence_length = 128
hidden_states = random . normal ( key , ( batch_size , sequence_length , block . hidden_size ))
# Create attention mask
attention_mask = jnp . ones (( batch_size , 1 , 1 , sequence_length ))
# Create position ids
position_ids = jnp . arange ( sequence_length )[ None , :]
# Forward pass
output = block . forward ( hidden_states , attention_mask , position_ids )
print ( output . shape ) # prints: (8, 128, 768) FusedLandmarkAttention這是一個Pytorch Function類,它封裝了融合具有里程碑意義的注意機制的前向和向後功能。
forward(ctx, q, k, v, n_prefix_q, sm_scale, block_size)此功能執行了融合地標的正向通行證。
ctx :我們可以保存變量以在向後通過的對象。由Pytorch的自動克拉德系統提供。q :查詢張量。假定它是連續的,其形狀應為(批處理,n頭,seqlen_q,d)。k :鑰匙張量。假定它是連續的,其形狀應與Q的形狀相匹配,即(批次,nheads,seqlen_k,d)。v :值張量。假定它是連續的,其形狀應與Q和K的形狀相匹配,即(批次,nheads,seqlen_k,d)。n_prefix_q :查詢中的前綴數。sm_scale :SoftMax操作中使用的縮放係數。block_size :用於執行塊操作的塊大小。 o :來自融合具有里程碑意義的注意機制的前向傳球的輸出張量。 backward(ctx, do)此功能執行融合地標註意的向後通過,即計算梯度。
ctx :我們可以從該對像中檢索保存在正向通行中的變量。由Pytorch的自動克拉德系統提供。do :相對於正向函數的輸出,損失的梯度。 None 。fused_landmark_attention(q, k, v, is_mem, sm_scale=None, block_size=64)此功能是FusedLandmarkAttention類的方便包裝器。
q :查詢張量。k :鑰匙張量。v :值張量。is_mem :一個布爾張量,指示是否應將每個鍵值對作為內存視為。它的長度應與鍵的序列長度相同。sm_scale :SoftMax操作中使用的縮放係數。如果None ,它將設置為1.0 / sqrt(d) 。block_size :用於執行塊操作的塊大小。 這是如何使用fused_landmark_attention函數的基本示例。
import torch
from AttentionGrid import fused_landmark_attention
# Initialize some tensors
batch = 8
nheads = 12
seqlen = 128
d = 64
q = torch . randn ( batch , nheads , seqlen , d )
k = torch . randn ( batch , nheads , seqlen , d )
v = torch . randn ( batch , nheads , seqlen , d )
is_mem = torch . zeros ( seqlen , dtype = torch . bool )
# Call the function
output = fused_landmark_attention ( q , k , v , is_mem )
print ( output . shape ) # prints: (8, 12, 128, 64)此示例首先初始化一些張量以作為查詢,鍵
和值。然後,它稱為fused_landmark_attention函數,並打印輸出張量的形狀。
import torch
import torch . nn as nn
from AttentionGrid import DilatedAttention
# Replace this with your correct GPU device
device = "cuda:0"
dtype = torch . float16
# Create an instance of DilatedAttention
d_model = 512
num_heads = 8
dilation_rate = 2
segment_size = 64
dropout = 0.2 # Specify the dropout rate
attention = DilatedAttention (
d_model = d_model ,
num_heads = num_heads ,
dilation_rate = dilation_rate ,
segment_size = segment_size ,
dropout = dropout ,
). to ( device , dtype = dtype )
# Create some dummy input data
batch_size = 16
seq_len = 128
input_dim = d_model
inputs = torch . randn ( batch_size , seq_len , input_dim , device = device , dtype = dtype )
# Forward pass
outputs = attention ( inputs )
# Print the output shape
print ( outputs . shape ) # Expected: [batch_size, seq_len, d_model]在上面的示例中,我們使用指定的超參數創建一個DilatedAttention類別的實例。然後,我們生成一些虛擬輸入數據,並將其通過注意機制以獲取輸出。最後,我們打印輸出張量的形狀。
DilatedAttention類具有擴張的注意力,隨著令牌之間的距離的增長,該階段會呈指數增長。它從torch.nn.Module繼承,可以用作變壓器模型中標準注意機制的倒入替換。
d_model (INT):輸入和輸出嵌入的維度。num_heads (int):注意力頭的數量。dilation_rate (INT):稀疏輸入序列的擴張率。segment_size (int):稀疏後每個段的大小。dropout (浮點,可選):輟學概率應用於注意力輸出。默認值:0.0(無輟學)。x (張量):形狀的輸入張量(batch_size, seq_len, d_model) 。output (張量):Shape的輸出張量(batch_size, seq_len, d_model) 。請注意,輸入張量應在正確的設備(例如GPU)上,並具有適當的數據類型( dtype )。