Selamat datang! Untuk proyek baru saya sekarang sangat merekomendasikan menggunakan proyek Jaxtyping baru saya sebagai gantinya. Ini mendukung Pytorch, tidak benar -benar bergantung pada Jax, dan tidak seperti torchtyping itu kompatibel dengan tipe checker statis. :)
Putar ini:
def batch_outer_product ( x : torch . Tensor , y : torch . Tensor ) -> torch . Tensor :
# x has shape (batch, x_channels)
# y has shape (batch, y_channels)
# return has shape (batch, x_channels, y_channels)
return x . unsqueeze ( - 1 ) * y . unsqueeze ( - 2 )ke dalam ini:
def batch_outer_product ( x : TensorType [ "batch" , "x_channels" ],
y : TensorType [ "batch" , "y_channels" ]
) -> TensorType [ "batch" , "x_channels" , "y_channels" ]:
return x . unsqueeze ( - 1 ) * y . unsqueeze ( - 2 )dengan pemeriksaan terprogram bahwa spesifikasi bentuk (dtype, ...) terpenuhi.
Bye-bye bug! Sapa dokumentasi kode Anda yang ditegakkan dan jelas.
Jika (seperti saya) Anda mendapati diri Anda mengotori kode Anda dengan komentar seperti # x has shape (batch, hidden_state) atau pernyataan seperti assert x.shape == y.shape , hanya untuk melacak apa bentuk semuanya, maka ini untuk Anda.
pip install torchtypingMembutuhkan Python> = 3.7 dan Pytorch> = 1.7.0.
Jika menggunakan typeguard maka itu harus menjadi versi <3.0.0.
torchtyping memungkinkan untuk anotasi tipe:
... ;torchtyping sangat dapat diperluas. Jika typeguard (opsional) diinstal maka pada saat runtime jenis dapat diperiksa untuk memastikan bahwa tensor benar -benar dari bentuk yang diiklankan, dType, dll.
# EXAMPLE
from torch import rand
from torchtyping import TensorType , patch_typeguard
from typeguard import typechecked
patch_typeguard () # use before @typechecked
@ typechecked
def func ( x : TensorType [ "batch" ],
y : TensorType [ "batch" ]) -> TensorType [ "batch" ]:
return x + y
func ( rand ( 3 ), rand ( 3 )) # works
func ( rand ( 3 ), rand ( 1 ))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3. typeguard juga memiliki kait impor yang dapat digunakan untuk menguji seluruh modul secara otomatis, tanpa perlu menambahkan @typeguard.typechecked secara manual.TypeChecked Decorator.
Jika Anda tidak menggunakan typeguard maka torchtyping.patch_typeguard() dapat dihilangkan sama sekali, dan torchtyping hanya digunakan untuk tujuan dokumentasi. Jika Anda belum menggunakan typeguard untuk pemrograman Python biasa, maka sangat pertimbangkan untuk menggunakannya. Ini cara yang bagus untuk merobohkan bug. Baik typeguard dan torchtyping juga berintegrasi dengan pytest , jadi jika Anda khawatir tentang penalti kinerja maka mereka dapat diaktifkan selama tes saja.
torchtyping . TensorType [ shape , dtype , layout , details ]Inti perpustakaan.
Setiap shape , dtype , layout , details adalah opsional.
shape bisa salah satu dari:int : Dimensi harus seukuran ini. Jika -1 maka ukuran apa pun diizinkan.str : Ukuran dimensi yang dilewati saat runtime akan terikat pada nama ini, dan semua tensor memeriksa bahwa ukurannya konsisten.... : Sejumlah dimensi sewenang -wenang dari ukuran apa pun.str: int (secara teknis itu adalah irisan), menggabungkan perilaku str dan int . (Hanya str sendiri setara dengan str: -1 .)str: str , dalam hal ini ukuran dimensi yang dilewati saat runtime akan terikat pada kedua nama, dan semua dimensi dengan kedua nama harus memiliki ukuran yang sama. (Beberapa orang suka menggunakan ini sebagai cara untuk mengaitkan beberapa nama dengan dimensi, untuk tujuan dokumentasi tambahan.)str: ... pasangan, dalam hal ini beberapa dimensi yang sesuai dengan ... akan terikat pada nama yang ditentukan oleh str , dan sekali lagi diperiksa konsistensi antara argumen.None , yang bila digunakan bersama dengan is_named di bawah ini, menunjukkan dimensi yang tidak boleh memiliki nama dalam arti tensor bernama.None: int pair, menggabungkan perilaku None dan int . (Hanya None sendiri yang setara dengan None: -1 .)None: str , menggabungkan perilaku None dan str . (Artinya, itu tidak boleh memiliki dimensi bernama, tetapi harus berukuran konsisten dengan penggunaan string lainnya.)typing.Any : Ukuran apa pun diizinkan untuk dimensi ini (setara dengan -1 ).TensorType["batch": ..., "length": 10, "channels", -1] . Jika Anda hanya ingin menentukan jumlah dimensi maka gunakan misalnya TensorType[-1, -1, -1] untuk tensor tiga dimensi.dtype bisa jadi salah satu dari:torch.float32 , torch.float64 dll.int , bool , float , yang dikonversi menjadi tipe Pytorch yang sesuai. float secara khusus ditafsirkan sebagai torch.get_default_dtype() , yang biasanya float32 .layout dapat berupa torch.strided atau torch.sparse_coo , untuk masing -masing tegang yang padat dan jarang.details menawarkan cara untuk melewati sejumlah bendera tambahan yang sewenang -wenang yang menyesuaikan dan memperluas torchtyping . Dua bendera dibangun secara default. torchtyping.is_named menyebabkan nama -nama dimensi tensor diperiksa, dan torchtyping.is_float dapat digunakan untuk memeriksa apakah tipe floating point yang sewenang -wenang dilewatkan. (Daripada hanya yang spesifik seperti halnya TensorType[torch.float32] .) Untuk diskusi tentang bagaimana torchtyping details Anda sendiri.[] . Misalnya TensorType["batch": ..., "length", "channels", float, is_named] . torchtyping . patch_typeguard () torchtyping terintegrasi dengan typeguard untuk melakukan pemeriksaan jenis runtime. torchtyping.patch_typeguard() harus dipanggil di tingkat global, dan akan menambal typeguard untuk memeriksa TensorType s.
Fungsi ini aman untuk dijalankan beberapa kali. (Tidak ada apa -apa setelah lari pertama).
@typeguard.typechecked , maka torchtyping.patch_typeguard() harus dipanggil kapan saja sebelum menggunakan @typeguard.typechecked . Misalnya Anda bisa menyebutnya di awal setiap file menggunakan torchtyping .typeguard.importhook.install_import_hook , maka torchtyping.patch_typeguard() harus dipanggil kapan saja sebelum mendefinisikan fungsi yang ingin Anda periksa. Misalnya Anda dapat memanggil torchtyping.patch_typeguard() hanya sekali, pada saat yang sama dengan pengait impor typeguard . (Urutan kait dan tambalan tidak masalah.)typeguard maka torchtyping.patch_typeguard() dapat dihilangkan sama sekali, dan torchtyping hanya digunakan untuk tujuan dokumentasi. pytest --torchtyping-patch-typeguard torchtyping menawarkan plugin pytest untuk secara otomatis menjalankan torchtyping.patch_typeguard() sebelum pengujian Anda. pytest akan secara otomatis menemukan plugin, Anda hanya perlu melewati --torchtyping-patch-typeguard bendera untuk memungkinkannya. Paket kemudian dapat diteruskan ke typeguard seperti biasa, baik dengan menggunakan @typeguard.typechecked , hook impor typeguard , atau bendera pytest --typeguard-packages="your_package_here" .
Lihat dokumentasi lebih lanjut untuk:
flake8 dan mypy ;torchtyping ;