Construit avec Jax et Pint!
Ce module fournit une interface entre Jax et Pint pour permettre à Jax de prendre en charge les opérations avec des unités. La propagation des unités se produit au moment de la trace, donc les fonctions JiTe ne devraient voir aucun coût d'exécution. Cette bibliothèque est expérimentale, alors attendez-vous à des arêtes vives.
Par exemple:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) > Pour installer, utilisez pip :
python -m pip install jpu Les seules dépendances sont jax et pint , et celles-ci seront également installées, sinon déjà dans votre environnement. Jetez un œil aux JAX Docs pour plus d'informations sur l'installation de JAX sur différents systèmes.
Voici un exemple un peu plus complet:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... ) La limitation la plus importante de cette bibliothèque est le fait que les utilisateurs doivent utiliser les fonctions jpu.numpy lors de l'interaction avec des "quantités" avec des unités au lieu de l'interface jax.numpy . En effet, JAX ne fournit pas (encore?) Une interface générale pour la répartition des UFUNC sur les classes de tableau personnalisées. J'ai joué avec l'interface __jax_array__ sans papiers, mais ce n'est pas vraiment flexible, et il n'est actuellement pas compatible avec les objets Pytree.
Jusqu'à présent, seul un sous-ensemble de l'interface numpy / jax.numpy est implémenté. Les demandes de traction ajoutant un support plus large (y compris les sous-modules) seraient les bienvenues!