Aktueller CI -Status:
Pytorch/XLA ist ein Python -Paket, das den XLA Deep Learning Compiler verwendet, um den Pytorch Deep Learning Framework und den Cloud -TPUs zu verbinden. Sie können es jetzt kostenlos auf einer einzelnen Cloud -TPU -VM mit Kaggle ausprobieren!
Schauen Sie sich eines unserer Kaggle -Notizbücher an, um loszulegen:
So installieren Sie Pytorch/XLA -stabilen Build in einem neuen TPU -VM:
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html
So installieren Sie Pytorch/XLA Nightly Build in einem neuen TPU -VM:
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
Pytorch/XLA bietet jetzt GPU -Unterstützung über ein Plugin -Paket, libtpu ähnelt:
pip install torch~=2.5.0 torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl
Um Ihre vorhandene Trainingsschleife zu aktualisieren, nehmen Sie die folgenden Änderungen vor:
- import torch.multiprocessing as mp
+ import torch_xla as xla
+ import torch_xla.core.xla_model as xm
def _mp_fn(index):
...
+ # Move the model paramters to your XLA device
+ model.to(xla.device())
for inputs, labels in train_loader:
+ with xla.step():
+ # Transfer data to the XLA device. This happens asynchronously.
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
- optimizer.step()
+ # `xm.optimizer_step` combines gradients across replicas
+ xm.optimizer_step(optimizer)
if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
+ # xla.launch automatically selects the correct world size
+ xla.launch(_mp_fn, args=()) Wenn Sie DistributedDataParallel verwenden, führen Sie die folgenden Änderungen vor:
import torch.distributed as dist
- import torch.multiprocessing as mp
+ import torch_xla as xla
+ import torch_xla.distributed.xla_backend
def _mp_fn(rank):
...
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
- dist.init_process_group("gloo", rank=rank, world_size=world_size)
+ # Rank and world size are inferred from the XLA device runtime
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
+ # `gradient_as_bucket_view=True` required for XLA
+ ddp_model = DDP(model, gradient_as_bucket_view=True)
- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
for inputs, labels in train_loader:
+ with xla.step():
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
+ xla.launch(_mp_fn, args=())Weitere Informationen zu Pytorch/XLA, einschließlich einer Beschreibung seiner Semantik und Funktionen, finden Sie unter pytorch.org. Sehen Sie sich den API -Leitfaden für Best Practices beim Schreiben von Netzwerken an, die auf XLA -Geräten ausgeführt werden (TPU, CUDA, CPU und ...).
Unsere umfassenden Benutzerführer sind verfügbar unter:
Dokumentation für die neueste Version
Dokumentation für Master Branch
Die Pytorch/XLA -Veröffentlichungen beginnen mit der Version R2.1 auf PYPI. Sie können jetzt den Hauptbau mit pip install torch_xla . So installieren Sie auch das Cloud -TPU -Plugin, das Ihrer tpu torch_xla entspricht
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
GPU und nächtliche Builds sind in unserem öffentlichen GCS -Eimer erhältlich.
| Version | Cloud GPU VM Räder |
|---|---|
| 2,5 (CUDA 12,1 + Python 3,9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,4 + Python 3,9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
| Nacht (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl |
| Nacht (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl |
| Nacht (CUDA 12,1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl |
pip3 install torch==2.6.0.dev20240925+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly%2B20240925-cp310-cp310-linux_x86_64.whl
Die Torch Wheel Version 2.6.0.dev20240925+cpu finden Sie unter https://download.pytorch.org/whl/nightly/torch/.
Sie können auch yyyymmdd nach torch_xla-2.6.0.dev hinzufügen, um das nächtliche Rad eines bestimmten Datums zu erhalten. Hier ist ein Beispiel:
pip3 install torch==2.5.0.dev20240820+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240820-cp310-cp310-linux_x86_64.whl
Die Torch Wheel Version 2.6.0.dev20240925+cpu finden Sie unter https://download.pytorch.org/whl/nightly/torch/.
| Version | Wolken -TPU -VMS -Rad |
|---|---|
| 2.4 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.3 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.2 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.1 (XRT + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.1 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl |
| Version | GPU -Rad |
|---|---|
| 2,5 (CUDA 12,1 + Python 3,9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,4 + Python 3,9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2,5 (CUDA 12,4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
| 2,4 (CUDA 12,1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl |
| 2,4 (CUDA 12,1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.4 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl |
| 2,3 (CUDA 12,1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl |
| 2,3 (CUDA 12,1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2,3 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl |
| 2,2 (CUDA 12,1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
| 2.2 (CUDA 12,1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
| 2.1 + cuda 11.8 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl |
| Nightly + CUDA 12.0> = 2023/06/27 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
| Version | Cloud TPU VMS Docker |
|---|---|
| 2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm |
| 2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm |
| 2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm |
| 2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm |
| 2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm |
| Nightly Python | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
Um die oben genannten Docker zu verwenden, passieren Sie bitte --privileged --net host --shm-size=16G entlang. Hier ist ein Beispiel:
docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash| Version | GPU CUDA 12.4 Docker |
|---|---|
| 2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4 |
| 2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4 |
| Version | GPU CUDA 12.1 Docker |
|---|---|
| 2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1 |
| 2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1 |
| 2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1 |
| 2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1 |
| 2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1 |
| Nacht- | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 |
| Nachts zum Zeitpunkt | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD |
| Version | GPU CUDA 11.8 + Docker |
|---|---|
| 2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8 |
| 2.0 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8 |
Recheninstanzen mit GPUs ausführen.
Wenn Pytorch/XLA nicht wie erwartet funktioniert, finden Sie in der Fehlerbehebung, die Vorschläge zum Debuggen und Optimieren Ihres Netzwerks enthält.
Das Pytorch/XLA -Team freut sich immer, von Benutzern und OSS -Mitwirkenden zu hören! Der beste Weg, um zu erreichen, besteht darin, ein Problem in diesem GitHub einzureichen. Fragen, Fehlerberichte, Feature -Anfragen, Erstellen von Problemen usw. sind alle willkommen!
Siehe den Beitragsleitfaden.
Dieses Repository wird gemeinsam von Google, Meta und einer Reihe von individuellen Mitwirkenden betrieben und verwaltet, die in der Datei der Mitwirkenden aufgeführt sind. Für Fragen, die bei Meta angegeben sind, senden Sie bitte eine E -Mail an [email protected]. Für Fragen, die bei Google gerichtet sind, senden Sie bitte eine E-Mail an [email protected]. Für alle anderen Fragen öffnen Sie hier ein Problem in diesem Repository.
Sie können zusätzliche nützliche Lesematerialien finden in