
Tianshou (天授 天授) est une bibliothèque d'apprentissage par renforcement (RL) basée sur le pytorch pur et le gymnase. Les principales caractéristiques de Tianshou en un coup d'œil sont:
Contrairement à d'autres bibliothèques d'apprentissage par renforcement, qui peuvent avoir des bases de code complexes, des API de haut niveau hostiles, ou ne sont pas optimisées pour la vitesse, Tianshou fournit un cadre modularisé haute performance et des interfaces conviviales pour la construction d'agents d'apprentissage en renforcement profond. Un autre aspect qui distingue Tianshou est sa généralité: elle prend en charge les algorithmes en ligne et hors ligne, RL multi-agents et basés sur des modèles.
Tianshou vise à permettre des implémentations concises, tant pour les chercheurs et les praticiens, sans sacrifier la flexibilité.
Les algorithmes pris en charge comprennent:
Autres caractéristiques notables:
En chinois, Tianshou signifie divinement ordonné, dérivé du don de naître. Tianshou est une plate-forme d'apprentissage de renforcement, et la nature de la RL n'est pas apprendre des humains. Donc, prendre "Tianshou" signifie qu'il n'y a pas de professeur à apprendre, mais plutôt à apprendre par soi-même par une interaction constante avec l'environnement.
«天授» 意指上天所授 , 引申为与生具有的天赋。天授是强化学习平台 , 而强化学习算法并不是向人类学习的 , 所以取 所以取 所以取 天授 天授 天授 意思是没有老师来教 , 而是自己通过跟环境不断交互来进行学习。
Tianshou est actuellement hébergé sur PYPI et Conda-Forge. Il nécessite Python> = 3.11.
Pour installer la version la plus récente de Tianshou, la meilleure façon est de cloner le référentiel et de l'installer avec de la poésie (que vous devez d'abord installer sur votre système)
git clone [email protected]:thu-ml/tianshou.git
cd tianshou
poetry install Vous pouvez également installer les exigences Dev en ajoutant --with dev ou les extras pour dire Mujoco et accélération par Envpool en ajoutant --extras "mujoco envpool"
Si vous souhaitez installer plusieurs extras, assurez-vous de les inclure dans une seule commande. Les appels séquentiels vers poetry install --extras xxx écraseront les installations antérieures, ne laissant que les derniers extras spécifiés installés. Ou vous pouvez installer tous les extras suivants en ajoutant --all-extras .
Les extras disponibles sont:
atari (pour les environnements Atari)box2d (pour les environnements Box2D)classic_control (pour les environnements de contrôle classique (discret))mujoco (pour les environnements Mujoco)mujoco-py (pour les environnements Mujoco-Py hérités 1 )pybullet (pour les environnements Pybullet)robotics (pour les environnements de gymnase-robotique)vizdoom (pour les environnements vizdoom)envpool (pour l'intégration Envpool)argparse (afin de pouvoir exécuter les exemples d'API de haut niveau)Sinon, vous pouvez installer la dernière version de PYPI (actuellement loin derrière le maître) avec la commande suivante:
$ pip install tianshouSi vous utilisez Anaconda ou MiniConda, vous pouvez installer Tianshou de Conda-Forge:
$ conda install tianshou -c conda-forgeAlternativement à l'installation de la poésie, vous pouvez également installer la dernière version source via GitHub:
$ pip install git+https://github.com/thu-ml/tianshou.git@master --upgradeEnfin, vous pouvez vérifier l'installation via votre console Python comme suit:
import tianshou
print ( tianshou . __version__ )Si aucune erreur n'est signalée, vous avez réussi à installer Tianshou.
Les tutoriels et la documentation de l'API sont hébergés sur tianshou.readthedocs.io.
Trouvez des exemples de scripts dans le test / et les exemples / dossiers.
| Plate-forme RL | Documentation | Couverture de code | Type Indices | Dernière mise à jour |
|---|---|---|---|---|
| Basélines stables3 | ✔️ | |||
| Ray / rllib | ➖ (1) | ✔️ | ||
| Rotation | ||||
| Dopamine | ||||
| ACMÉ | ➖ (1) | ✔️ | ||
| Échantillon d'usine | ➖ | |||
| Tianshou | ✔️ |
(1): il a une intégration continue mais le taux de couverture n'est pas disponible
Tianshou est rigoureusement testé. Contrairement à d'autres plates-formes RL, nos tests incluent la procédure de formation complète de l'agent pour tous les algorithmes implémentés . Nos tests échoueraient une fois si l'un des agents n'avait pas atteint un niveau de performance cohérent sur des époques limitées. Nos tests assurent ainsi la reproductibilité. Consultez la page des actions GitHub pour plus de détails.
Les résultats de référence Atari et Mujoco peuvent être trouvés dans les exemples / atari / et exemples / mujoco / dossiers respectivement. Nos résultats Mujoco atteignent ou dépassent le niveau de performance de la plupart des repères existants.
Tous les algorithmes implémentent l'API très générale suivante:
__init__ : initialisez la politique;forward : calculer les actions basées sur des observations données;process_buffer : procéder le tampon initial, qui est utile pour certains algorithmes d'apprentissage hors ligneprocess_fn : données de prétraitement du tampon de relecture (puisque nous avons reformulé tous les algorithmes pour rejouer des algorithmes basés sur le tampon);learn : Apprenez d'un lot de données donné;post_process_fn : Mettez à jour le tampon de relecture à partir du processus d'apprentissage (par exemple, le tampon de relecture prioritaire doit mettre à jour le poids);update : l'interface principale de la formation, c'est-à-dire process_fn -> learn -> post_process_fn .La mise en œuvre de cette API suffit qu'un nouvel algorithme soit applicable au sein de Tianshou, rendant l'expérimentation avec de nouvelles approches particulièrement simples.
Tianshou fournit deux niveaux d'API:
Dans ce qui suit, considérons un exemple d'application en utilisant l'environnement de gymnase Cartpole . Nous appliquerons l'algorithme d'apprentissage du réseau Q (DQN) Deep Q à l'aide des deux API.
Pour commencer, nous avons besoin de quelques importations.
from tianshou . highlevel . config import SamplingConfig
from tianshou . highlevel . env import (
EnvFactoryRegistered ,
VectorEnvType ,
)
from tianshou . highlevel . experiment import DQNExperimentBuilder , ExperimentConfig
from tianshou . highlevel . params . policy_params import DQNParams
from tianshou . highlevel . trainer import (
EpochTestCallbackDQNSetEps ,
EpochTrainCallbackDQNSetEps ,
EpochStopCallbackRewardThreshold
) Dans l'API de haut niveau, la base d'une expérience RL est un ExperimentBuilder avec lequel nous pouvons construire l'expérience que nous cherchons ensuite à exécuter. Puisque nous voulons utiliser DQN, nous utilisons la spécialisation DQNExperimentBuilder . Les autres importations servent à fournir des options de configuration pour notre expérience.
L'API de haut niveau fournit une sémantique largement déclarative, c'est-à-dire que le code est presque exclusivement concerné par la configuration qui contrôle ce qu'il faut faire (plutôt que comment le faire).
experiment = (
DQNExperimentBuilder (
EnvFactoryRegistered ( task = "CartPole-v1" , train_seed = 0 , test_seed = 0 , venv_type = VectorEnvType . DUMMY ),
ExperimentConfig (
persistence_enabled = False ,
watch = True ,
watch_render = 1 / 35 ,
watch_num_episodes = 100 ,
),
SamplingConfig (
num_epochs = 10 ,
step_per_epoch = 10000 ,
batch_size = 64 ,
num_train_envs = 10 ,
num_test_envs = 100 ,
buffer_size = 20000 ,
step_per_collect = 10 ,
update_per_step = 1 / 10 ,
),
)
. with_dqn_params (
DQNParams (
lr = 1e-3 ,
discount_factor = 0.9 ,
estimation_step = 3 ,
target_update_freq = 320 ,
),
)
. with_model_factory_default ( hidden_sizes = ( 64 , 64 ))
. with_epoch_train_callback ( EpochTrainCallbackDQNSetEps ( 0.3 ))
. with_epoch_test_callback ( EpochTestCallbackDQNSetEps ( 0.0 ))
. with_epoch_stop_callback ( EpochStopCallbackRewardThreshold ( 195 ))
. build ()
)
experiment . run ()Le constructeur d'expérience prend trois arguments:
watch=True ) pour un certain nombre d'épisodes ( watch_num_episodes=100 ). Nous avons une persistance handicapée, car nous ne voulons pas enregistrer les journaux de formation, l'agent ou sa configuration pour une utilisation future.num_epochs=10 )step_per_epoch=10000 ). Chaque époque se compose d'une série d'étapes de collecte de données (de déploiement) et d'étapes de formation. Le paramètre step_per_collect contrôle la quantité de données collectées à chaque étape de collecte et après chaque étape de collecte, nous effectuons une étape de formation, appliquant une mise à jour basée sur le gradient basée sur un échantillon de données ( batch_size=64 ) tiré du tampon de données qui a été collectée. Pour plus de détails, consultez la documentation de SamplingConfig .Nous procédons ensuite à la configuration de certains des paramètres de l'algorithme DQN lui-même et du modèle de réseau neuronal que nous voulons utiliser. Un détail spécifique au DQN est l'utilisation de rappels pour configurer le paramètre Epsilon de l'algorithme pour l'exploration. Nous voulons utiliser une exploration aléatoire pendant les déploiements (rappel de train), mais nous ne le faisons pas lors de l'évaluation des performances de l'agent dans les environnements de test (rappel de test).
Trouvez le script dans des exemples / discrets / discrets_dqn_hl.py. Voici une course (avec le temps d'entraînement interrompu):

Trouvez de nombreuses applications supplémentaires de l'API de haut niveau dans les examples/ dossiers; Recherchez les scripts se terminant par _hl.py . Notez que la plupart de ces exemples nécessitent le package supplémentaire argparse (l'installez en ajoutant --extras argparse lors de l'appel de poésie).
Voyons maintenant un exemple analogue dans l'API procédurale. Trouvez le script complet dans des exemples / discrets / discrets_dqn.py.
Tout d'abord, importez des packages pertinents:
import gymnasium as gym
import torch
from torch . utils . tensorboard import SummaryWriter
import tianshou as tsDéfinissez certains hyper-paramètres:
task = 'CartPole-v1'
lr , epoch , batch_size = 1e-3 , 10 , 64
train_num , test_num = 10 , 100
gamma , n_step , target_freq = 0.9 , 3 , 320
buffer_size = 20000
eps_train , eps_test = 0.1 , 0.05
step_per_epoch , step_per_collect = 10000 , 10Initialisez l'enregistreur:
logger = ts . utils . TensorboardLogger ( SummaryWriter ( 'log/dqn' ))
# For other loggers, see https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.htmlFaire des environnements:
# You can also try SubprocVectorEnv, which will use parallelization
train_envs = ts . env . DummyVectorEnv ([ lambda : gym . make ( task ) for _ in range ( train_num )])
test_envs = ts . env . DummyVectorEnv ([ lambda : gym . make ( task ) for _ in range ( test_num )])Créez le réseau ainsi que son optimiseur:
from tianshou . utils . net . common import Net
# Note: You can easily define other networks.
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
env = gym . make ( task , render_mode = "human" )
state_shape = env . observation_space . shape or env . observation_space . n
action_shape = env . action_space . shape or env . action_space . n
net = Net ( state_shape = state_shape , action_shape = action_shape , hidden_sizes = [ 128 , 128 , 128 ])
optim = torch . optim . Adam ( net . parameters (), lr = lr )Configurer la politique et les collectionneurs:
policy = ts . policy . DQNPolicy (
model = net ,
optim = optim ,
discount_factor = gamma ,
action_space = env . action_space ,
estimation_step = n_step ,
target_update_freq = target_freq
)
train_collector = ts . data . Collector ( policy , train_envs , ts . data . VectorReplayBuffer ( buffer_size , train_num ), exploration_noise = True )
test_collector = ts . data . Collector ( policy , test_envs , exploration_noise = True ) # because DQN uses epsilon-greedy methodFormons-le:
result = ts . trainer . OffpolicyTrainer (
policy = policy ,
train_collector = train_collector ,
test_collector = test_collector ,
max_epoch = epoch ,
step_per_epoch = step_per_epoch ,
step_per_collect = step_per_collect ,
episode_per_test = test_num ,
batch_size = batch_size ,
update_per_step = 1 / step_per_collect ,
train_fn = lambda epoch , env_step : policy . set_eps ( eps_train ),
test_fn = lambda epoch , env_step : policy . set_eps ( eps_test ),
stop_fn = lambda mean_rewards : mean_rewards >= env . spec . reward_threshold ,
logger = logger ,
). run ()
print ( f"Finished training in { result . timing . total_time } seconds" ) Enregistrer / charger la politique formée (c'est exactement la même chose que le chargement d'une torch.nn.module ):
torch . save ( policy . state_dict (), 'dqn.pth' )
policy . load_state_dict ( torch . load ( 'dqn.pth' ))Regardez l'agent avec 35 ips:
policy . eval ()
policy . set_eps ( eps_test )
collector = ts . data . Collector ( policy , env , exploration_noise = True )
collector . collect ( n_episode = 1 , render = 1 / 35 )Inspectez les données enregistrées dans Tensorboard:
$ tensorboard --logdir log/dqnVeuillez lire la documentation pour une utilisation avancée.
Tianshou est toujours en cours de développement. D'autres algorithmes et fonctionnalités sont continuellement ajoutés, et nous accueillons toujours les contributions pour aider à améliorer Tianshou. Si vous souhaitez contribuer, veuillez consulter ce lien.
Si vous trouvez Tianshou utile, veuillez le citer dans vos publications.
@article{tianshou,
author = {Jiayi Weng and Huayu Chen and Dong Yan and Kaichao You and Alexis Duburcq and Minghao Zhang and Yi Su and Hang Su and Jun Zhu},
title = {Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
journal = {Journal of Machine Learning Research},
year = {2022},
volume = {23},
number = {267},
pages = {1--6},
url = {http://jmlr.org/papers/v23/21-1127.html}
}Tianshou est soutenu par l'Appliedai Institute for Europe, qui s'engage à fournir un soutien et un développement à long terme.
Tianshou était auparavant une plate-forme d'apprentissage de renforcement basée sur TensorFlow. Vous pouvez consulter la branche priv pour plus de détails. Un grand merci au travail pionnier de Haosheng Zou pour Tianshou avant la version 0.1.1.
Nous tenons à remercier Tsail et Institute pour l'intelligence artificielle de l'Université Tsinghua d'avoir fourni une excellente plate-forme de recherche sur l'IA.
mujoco-py est un package hérité et n'est pas recommandé pour les nouveaux projets. Il n'est inclus que pour la compatibilité avec les projets plus anciens. Notez également qu'il peut y avoir des problèmes de compatibilité avec MacOS plus récent que Monterey. ↩