
Felafax는 XLA 런타임을 사용하여 지속적인 훈련 및 미세 조정 오픈 소스 LLM을위한 프레임 워크입니다. 우리는 필요한 런타임 설정을 처리하고 시작하기 위해 Jupyter 노트북을 제공합니다.
Felafax의 목표는 Nonvidia 하드웨어 (TPU, AWS Trainium, AMD GPU 및 Intel GPU)에서 AI 워크로드를보다 쉽게 실행할 수 있도록 Infra를 구축하는 것입니다.
데이터 세트를 추가하고 "모두 실행"을 클릭하면 Google Colab에서 무료 TPU 리소스에서 실행됩니다!
| Felafax 지원 | 무료 노트북 |
|---|---|
| 라마 3.1 (1B, 3B) |
LLAMA-3.1 JAX 구현
LLAMA-3/3.1 PYTORCH XLA
몇 가지 간단한 단계로 Felafax CLI를 사용하여 모델을 미세 조정하기 시작하십시오.
CLI를 설치하여 시작하십시오.
pip install pipx
pipx install felafax-cli그런 다음 인증 토큰을 생성합니다.
마지막으로 토큰을 사용하여 CLI 세션을 인증하십시오.
felafax-cli auth login --token < your_token > 먼저 미세 조정을위한 기본 구성 파일을 생성하십시오. 이 명령은 기본 하이퍼 파라미터 값으로 현재 디렉토리에서 config.yml 파일을 생성합니다.
felafax-cli tune init-config둘째, 하이퍼 파라미터로 구성 파일을 업데이트하십시오.
포옹 페이스 손잡이 :
데이터 세트 파이프 라인 및 교육 매개 변수 :
batch_size , max_seq_length 를 조정하십시오.null 로 설정하십시오. num_steps가 숫자로 설정되면 지정된 수의 단계 후에 훈련이 중지됩니다.learning_rate 및 lora_rank 설정하십시오.eval_interval 은 평가 사이의 단계 수입니다.다음 명령을 실행하려면 미세 조정할 수있는 기본 모델 목록을 확인하십시오. 현재 LLAMA-3.1의 모든 변형을 지원합니다.
felafax-cli tune start --help 이제 위 목록에서 선택한 모델과 Huggingface의 데이터 세트 이름으로 미세 조정 프로세스 yahma/alpaca-cleaned 시작할 수 있습니다.
felafax-cli tune start --model < your_selected_model > --config ./config.yml --hf-dataset-id < your_hf_dataset_name >예제 명령을 시작하려면 :
felafax-cli tune start --model llama3-2-1b --config ./config.yml --hf-dataset-id yahma/alpaca-cleaned미세 조정 작업을 시작한 후 Felafax CLI는 TPU를 회전시키고 교육을 실행하고 미세 조정 모델을 Huggingface 허브에 업로드합니다.
실시간 로그를 스트리밍하여 미세 조정 작업의 진행 상황을 모니터링 할 수 있습니다.
# Use `<job_name>` with the job namethat you get after starting the fine-tuning.
felafax-cli tune logs --job-id < job_name > -f미세 조정이 완료되면 모든 미세 조정 모델을 나열 할 수 있습니다.
felafax-cli model list대화 형 터미널 세션을 시작하여 미세 조정 된 모델과 채팅 할 수 있습니다.
# Replace `<model_id>` with model id from `model list` command you ran above.
felafax-cli model chat --model-id < model_id > CLI는 세 가지 주요 명령 그룹으로 나뉩니다.
tune : 미세 조정 작업을 시작/중지합니다.model : 미세 조정 된 모델을 관리하고 상호 작용합니다.files : DatASet 파일을 업로드/보기. --help 플래그를 사용하여 모든 명령 그룹에 대해 자세히 알아보십시오.
felafax-cli tune --help우리는 최근 Pytorch 대신 Jax를 사용하여 8xAMD MI300X GPU에서 LLAMA3.1 405B 모델을 미세 조정했습니다. Jax의 Advanced Sharding API를 통해 우리는 훌륭한 성능을 달성 할 수있었습니다. 블로그 게시물을 확인하여 우리가 사용한 설정 및 샤드 트릭에 대해 알아보십시오.
우리는 Bfloat16 정밀도의 모든 모델 가중치 및 LORA 매개 변수로 LORA 미세 조정을했고 LORA 순위는 8, Lora Alpha는 16입니다.
GPU 활용 및 VRAM 사용법 그래프는 아래에서 찾을 수 있습니다. 그러나 여전히 MFU (Model Flops Utilization)를 계산해야합니다. 참고 : 인프라 및 VRAM 제약으로 인해 405B 모델의 JIT 컴파일 버전을 실행할 수 없었습니다 (이를 더 조사해야 함). 전체 교육 실행은 JAX Eger 모드에서 실행되었으므로 성능 향상에 대한 잠재력이 상당합니다.


궁금한 점이 있으면 설립자@felafax.ai로 문의하십시오.