
更新(4/8/2024) :JORA现在支持Google的Gemma模型。
更新(4/11/2024) :Gemma 1.1添加了支持
大型语言模型(LLMS)的基于检索任务的缩放,尤其是在检索增强发电(RAG)中,面临着重大的内存限制,尤其是在微调广泛的及时序列时。当前的开源库支持多个GPU的全模型推理和微调,但缺乏适应检索到上下文所需的有效参数分布。在解决这一差距的情况下,我们引入了一个新颖的框架,用于利用分布式培训的Llama-2模型的PEFT兼容微调。我们的框架独特地利用了JAX的Jax(JIT)汇编和张量缩减来进行有效的资源管理,从而促进了随着内存要求减少的加速微调。这种进步显着提高了微调LLMs对复杂的抹布应用的可扩展性和可行性,即使在GPU资源有限的系统上也是如此。我们的实验表明,与拥抱面部/深速实施相比,运行时的提高超过12倍,而四个GPU则消耗了少于每GPU的VRAM一半。
请确保您安装了最新版本的JAX。 https://github.com/google/jax
要安装软件包,请在存储库的根目录中运行以下命令:
git clone https://github.com/aniquetahir/JORA.git
cd JORA
pip install -e .确保JAX可以访问GPU:
import jax
print ( jax . devices ())可以通过python或提供GUI来使用该库。
Calallama类可用于定义配置。明智的参数设置为默认值。
class ParallamaConfig ( NamedTuple ):
JAX_PARAMS_PATH : str
LLAMA2_META_PATH : str # e.g. '/tmp/llama2-13B'
MODEL_SIZE : str # '7B', '13B', '70B'
NUM_GPUS : int = None
LORA_R : int = 16
LORA_ALPHA : int = 16
LORA_DROPOUT : float = 0.05
LR : float = 0.0001
BATCH_SIZE : int = 1
N_ACCUMULATION_STEPS : int = 8
MAX_SEQ_LEN = 2000
N_EPOCHS : int = 7
SEED : int = 420 基于Llama-2的模型
from jora import train_lora , ParallamaConfig , generate_alpaca_dataset
config = ParallamaConfig ( MODEL_SIZE = model_size , JAX_PARAMS_PATH = jax_path ,
LLAMA2_META_PATH = hf_path )
dataset = generate_alpaca_dataset ( dataset_path , 'train' , config )
train_lora ( config , dataset , checkpoint_path )可以从Kaggle下载基于Gemma的模型Flax Gemma模型:
import kagglehub
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it', '1.1-2b-it', '1.1-7b-it'] {type:"string"}
weights_dir = kagglehub . model_download ( f'google/gemma/Flax/ { VARIANT } ' )默认情况下,KaggleHub将模型存储在~/.cache/kagglehub目录中。
from jora import ParagemmaConfig , train_lora_gemma , generate_alpaca_dataset_gemma
# model version in '2b', '2b-it', '7b', '7b-it' (2b-it, 7b-it for Gemma 1.1)
config = ParagemmaConfig ( GEMMA_MODEL_PATH = model_path , MODEL_VERSION = model_version )
dataset = generate_alpaca_dataset_gemma ( dataset_path , 'train' , config )
train_lora_gemma ( config , dataset , checkpoint_path )== Gemma 1.1 ==
对于Gemma 1.1模型,KaggleHub将模型存储在以下目录结构中:
1.1-7b-it
├── 1
│ ├── 7b-it
│ └── tokenizer.model
└── 1.complete
因此,应将1.1-7b-it模型的config.MODEL_VERSION设置为7b-it 。
generate_alpaca_dataset函数用于从羊驼格式JSON文件中生成数据集。这有助于进行指导格式培训,因为数据集处理,令牌化和批处理由图书馆处理。另外,火炬Dataset和DataLoader可用于自定义数据集。
Huggingface具有广泛的生态系统。由于我们的图书馆使用JAX进行培训,因此由此产生的模型不兼容。为了解决此问题,我们提供了一个子模块,以将经过JAX训练的模型转换回HuggingFace格式。
训练有素的洛拉权重可以首先与原始参数合并:
SYNOPSIS
python -m jora.lora.merge PARAMS_PATH LORA_PATH OUTPUT_PATH <flags>
POSITIONAL ARGUMENTS
PARAMS_PATH
Type: str
LORA_PATH
Type: str
OUTPUT_PATH
Type: str
FLAGS
-l, --llama2=LLAMA2
Default: False
-g, --gemma=GEMMA
Default: False
NOTES
You can also use flags syntax for POSITIONAL ARGUMENTS
SYNOPSIS
python -m jora.hf HUGGINGFACE_PATH JAX_PATH SAVE_PATH
DESCRIPTION
This function takes a huggingface llama model and replaces the q_proj and v_proj weights with the lora merged weights
POSITIONAL ARGUMENTS
HUGGINGFACE_PATH
Type: str
path to the huggingface llama model
JAX_PATH
Type: str
path to the lora merged params
SAVE_PATH
Type: str
path to save the updated huggingface llama model
NOTES
You can also use flags syntax for POSITIONAL ARGUMENTS
GUI可用于训练模型。 GUI是通过运行以下命令开始的:
python -m jora.gui| GPU | 1 | 2 | 4 | |
|---|---|---|---|---|
| 带有Microsoft DeepSpeed Zero-3的拥抱脸PEFT | mem(MB) | 20645.2(39.81) | 23056 /23024(14.63 / 29.29) | 23978 /23921 /23463 /23397(47.87 / 50.39 / 31.96 / 17.46) |
| 性能(SEC) | 4.56(0.04) | 2.81(0.02) | 5.45(0.09) | |
| 乔拉(我们的) | mem(MB) | 23102(0.00) | 16068 /16008(0.00 / 0.00) | 11460 /11448 /11448 /11400(0.0 / 0.00 / 0.00 / 0.00) |
| 性能(SEC) | 0.19(0.00) | 0.79(0.00) | 0.44(0.00) |
在几个地方,捐款将受到赞赏。
@misc { tahir2024jora ,
title = { JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented Fine-Tuning } ,
author = { Anique Tahir and Lu Cheng and Huan Liu } ,
year = { 2024 } ,
eprint = { 2403.11366 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}JAX Llama-2模型实施AYAKA14732
Google DeepMind的亚麻Gemma模型实施