
Felafax ist ein Framework für die Fortsetzung und Feinabstimmung von Open Source LLMs mit XLA-Laufzeit . Wir kümmern uns um das notwendige Laufzeit-Setup und stellen ein Jupyter-Notizbuch außerhalb des Boxs zur Verfügung, um gerade loszulegen.
Unser Ziel bei Felafax ist es, Infra zu bauen, um die Durchführung von KI-Workloads auf Nicht-Nvidia-Hardware (TPU, AWS Trainium, AMD GPUs und Intel GPUs) zu erleichtern.
Fügen Sie Ihren Datensatz hinzu, klicken Sie auf "Alle ausführen" und Sie werden auf der kostenlosen TPU -Ressource in Google Colab ausgeführt!
| Felafax unterstützt | Kostenlose Notizbücher |
|---|---|
| Lama 3.1 (1b, 3b) |
LAMA-3.1 JAX-Implementierung
Lama-3/3.1 Pytorch Xla
Beginnen Sie mit der Feinabstimmung Ihrer Modelle mit der Felafax CLI in wenigen einfachen Schritten.
Beginnen Sie mit der Installation der CLI.
pip install pipx
pipx install felafax-cliGenerieren Sie dann ein Auth -Token:
Schließlich authentifizieren Sie Ihre CLI -Sitzung mit Ihrem Token:
felafax-cli auth login --token < your_token > Generieren Sie zunächst eine Standardkonfigurationsdatei für die Feinabstimmung. Dieser Befehl generiert eine config.yml -Datei im aktuellen Verzeichnis mit Standard -Hyperparameterwerten.
felafax-cli tune init-configZweitens aktualisieren Sie die Konfigurationsdatei mit Ihren Hyperparametern:
Umarmungsface -Knöpfe:
Datensatzpipeline und Trainingsparams:
batch_size , max_seq_length an, um für den Feinabstimmungsdatensatz zu verwenden.null fest, wenn Sie möchten, dass Trainig den gesamten Datensatz durchläuft. Wenn num_steps auf eine Nummer eingestellt ist, wird das Training nach der angegebenen Anzahl von Schritten gestoppt.learning_rate und lora_rank für Feinabstimmungen.eval_interval ist die Anzahl der Schritte zwischen Bewertungen.Führen Sie den Befehl folgen aus, um die Liste der Basismodelle anzuzeigen, die Sie gut abschneiden können. Wir unterstützen alle Varianten von Lama-3.1 ab sofort.
felafax-cli tune start --help Jetzt können Sie den Feinabstimmungsvorgang mit Ihrem ausgewählten Modell aus der obigen Liste und dem Datasetnamen von Huggingface (wie yahma/alpaca-cleaned ) starten:
felafax-cli tune start --model < your_selected_model > --config ./config.yml --hf-dataset-id < your_hf_dataset_name >Beispielbefehl, um Ihnen den Einstieg zu erleichtern:
felafax-cli tune start --model llama3-2-1b --config ./config.yml --hf-dataset-id yahma/alpaca-cleanedNachdem Sie den Feinabstimmungsjob gestartet haben, kümmert sich Felafax CLI darum, den TPUs zu spinnen, das Training auszuführen, und lädt das fein abgestimmte Modell in den Hubface-Hub hoch.
Sie können Echtzeit-Protokolle streamen, um den Fortschritt Ihres Feinabstimmungsjobs zu überwachen:
# Use `<job_name>` with the job namethat you get after starting the fine-tuning.
felafax-cli tune logs --job-id < job_name > -fNach Abschluss der Feinabstimmung können Sie alle Ihre fein abgestimmten Modelle auflisten:
felafax-cli model listSie können eine interaktive Terminalsitzung starten, um mit Ihrem fein abgestimmten Modell zu chatten:
# Replace `<model_id>` with model id from `model list` command you ran above.
felafax-cli model chat --model-id < model_id > Die CLI ist in drei Hauptbefehlsgruppen unterteilt:
tune : Starten/Beendigung von Arbeitsplätzen.model : Verwalten und Interaktion mit Ihren fein abgestimmten Modellen.files : Zum Hochladen/Anzeigen von Datendateien. Verwenden Sie das Flag --help , um mehr über jede Befehlsgruppe zu erfahren:
felafax-cli tune --helpWir haben kürzlich das LLAMA3.1 405B-Modell am 8xAMD MI300X GPUs mit JAX anstelle von Pytorch abgestimmt. Jax 'fortgeschrittene Sharding -APIs ermöglichte es uns, eine große Leistung zu erzielen. Schauen Sie sich unseren Blog -Beitrag an, um sich über das Setup und die von uns verwendeten Sharding -Tricks zu informieren.
Wir haben Lora mit allen Modellgewichten und Lora-Parametern in Bfloat16-Präzision und mit Lora-Rang von 8 und Lora Alpha von 16:
Die GPU -Nutzungs- und VRAM -Nutzungsdiagramme finden Sie unten. Wir müssen jedoch immer noch die Modellflops -Nutzung (MFU) berechnen. Hinweis: Wir konnten die JIT-kompilierte Version des 405B-Modells aufgrund von Infrastruktur- und VRAM-Einschränkungen nicht ausführen (wir müssen dies weiter untersuchen). Der gesamte Trainingslauf wurde im JAX -Eager -Modus ausgeführt, sodass ein erhebliches Potenzial für Leistungsverbesserungen besteht.


Wenn Sie Fragen haben, kontaktieren Sie uns bitte unter [email protected].