Thomas Schmied 1 ، Markus Hofmarcher 2 ، Fabian Paischer 1 ، Razvan Pacscanu 3،4 ، Sepp Hochreiter 1،5
1 وحدة Ellis Linz و Lit AI Lab ، معهد التعلم الآلي ، جامعة يوهانس كيبلر لينز ، النمسا
2 JKU LIT SAL ESPML Lab ، معهد التعلم الآلي ، جامعة جوهانس كيبلر لينز ، النمسا
3 Google DeepMind
4 UCL
5 معهد الأبحاث المتقدمة في الذكاء الاصطناعي (IARAI) ، فيينا ، النمسا
يحتوي هذا المستودع على الكود المصدري لـ "تعلم تعديل النماذج التي تم تدريبها مسبقًا في RL" المقبولة في Neurips 2023. الورقة متوفرة هنا.

تدعم قاعدة بيانات الكود هذه نماذج محولات القرار التدريبية (DT) عبر الإنترنت أو من مجموعات البيانات غير المتصل بالإنترنت في المجالات التالية:
تعتمد قاعدة قاعدة الشفرة هذه على أطر عمل مفتوحة المصدر ، بما في ذلك:
ماذا يوجد في هذا المستودع؟
.
├── configs # Contains all .yaml config files for Hydra to configure agents, envs, etc.
│ ├── agent_params
│ ├── wandb_callback_params
│ ├── env_params
│ ├── eval_params
│ ├── run_params
│ └── config.yaml # Main config file for Hydra - specifies log/data/model directories.
├── continual_world # Submodule for Continual-World.
├── dmc2gym_custom # Custom wrapper for DMControl.
├── figures
├── scripts # Scrips for running experiments on Slurm/PBS in multi-gpu/node setups.
├── src # Main source directory.
│ ├── algos # Contains agent/model/prompt classes.
│ ├── augmentations # Image augmentations.
│ ├── buffers # Contains replay trajectory buffers.
│ ├── callbacks # Contains callbacks for training (e.g., WandB, evaluation, etc.).
│ ├── data # Contains data utilities (e.g., for downloading Atari)
│ ├── envs # Contains functionality for creating environments.
│ ├── exploration # Contains exploration strategies.
│ ├── optimizers # Contains (custom) optimizers.
│ ├── schedulers # Contains learning rate schedulers.
│ ├── tokenizers_custom # Contains custom tokenizers for discretizing states/actions.
│ ├── utils
│ └── __init__.py
├── LICENSE
├── README.md
├── environment.yaml
├── requirements.txt
└── main.py # Main entry point for training/evaluating agents.
تتوفر تكوين البيئة وتبعياتها في environment.yaml و requirements.txt .
أولا ، إنشاء بيئة كوندا.
conda env create -f environment.yaml
conda activate mddt
ثم قم بتثبيت المتطلبات المتبقية (مع تنزيل Mujoco بالفعل ، إن لم يكن ترى هنا):
pip install -r requirements.txt
inter the continualworld Submodule وتثبيت:
git submodule init
git submodule update
cd continualworld
pip install .
تثبيت meta-world :
pip install git+https://github.com/rlworkgroup/metaworld.git@18118a28c06893da0f363786696cc792457b062b
تثبيت إصدار مخصص من DMC2Gym. إصدارنا يجعل flatten_obs اختياريًا ، وبالتالي يتيح لنا بناء مساحة المراقبة الكاملة لجميع ENCS DMControl.
cd dmc2gym_custom
pip install -e .
تحميل mujoco:
mkdir ~/.mujoco
cd ~/.mujoco
wget https://www.roboti.us/download/mujoco200_linux.zip
unzip mujoco200_linux.zip
mv mujoco200_linux mujoco200
wget https://www.roboti.us/file/mjkey.txt
ثم أضف السطر التالي إلى .bashrc :
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin
كانت القضايا التالية مفيدة:
أولاً ، قم بتثبيت الحزم التالية:
conda install -c conda-forge glew mesalib
conda install -c menpo glfw3 osmesa
pip install patchelf
قم بإنشاء Symlink يدويًا:
cp /usr/lib64/libGL.so.1 $CONDA_PREFIX/lib
ln -s $CONDA_PREFIX/lib/libGL.so.1 $CONDA_PREFIX/lib/libGL.so
ثم افعل:
mkdir ~/rpm
cd ~/rpm
curl -o libgcrypt11.rpm ftp://ftp.pbone.net/mirror/ftp5.gwdg.de/pub/opensuse/repositories/home:/bosconovic:/branches:/home:/elimat:/lsi/openSUSE_Leap_15.1/x86_64/libgcrypt11-1.5.4-lp151.23.29.x86_64.rpm
rpm2cpio libgcrypt11.rpm | cpio -id
أخيرًا ، قم بتصدير المسار إلى rpm Dir (أضف إلى ~/.bashrc ):
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/rpm/usr/lib64
export LDFLAGS="-L/~/rpm/usr/lib64"
تعتمد قاعدة الشفرة هذه على Hydra ، والتي تقوم بتكوين التجارب عبر ملفات .yaml . تقوم Hydra تلقائيًا بإنشاء بنية مجلد السجل لتشغيل معين ، كما هو محدد في ملف config.yaml المعني.
config.yaml هي نقطة إدخال التكوين الرئيسية وتحتوي على المعلمات الافتراضية. يشير الملف إلى ملفات المعلمة الافتراضية ذات الصلة تحت defaults كتلة. بالإضافة إلى ذلك ، يحتوي config.yaml على 4 ثوابت مهمة تقوم بتكوين مسارات الدليل:
LOG_DIR: ../logs
DATA_DIR: ../data
SSD_DATA_DIR: ../data
MODELS_DIR: ../models
يتم استضافة مجموعات البيانات التي يتم توليدها حاليًا عبر خادم الويب الخاص بنا. قم بتنزيل مجموعات بيانات Meta-World و DMControl إلى DATA_DIR المحددة:
# Meta-World
wget --recursive --no-parent --no-host-directories --cut-dirs=2 -R "index.html*" https://ml.jku.at/research/l2m/metaworld
# DMControl
wget --recursive --no-parent --no-host-directories --cut-dirs=2 -R "index.html*" https://ml.jku.at/research/l2m/dm_control_1M
تتوفر مجموعات البيانات أيضًا على Huggingface Hub. تنزيل باستخدام huggingface-cli :
# Meta-World
huggingface-cli download ml-jku/meta-world --local-dir=./meta-world --repo-type dataset
# DMControl
huggingface-cli download ml-jku/dm_control --local-dir=./dm_control --repo-type dataset
يدعم الإطار أيضًا مجموعات بيانات Atari و D4RL و DMControl المرئية. بالنسبة إلى Atari و DMControl المرئي ، نشير إلى readmes المعنية.
في ما يلي ، نقدم بعض الأمثلة التوضيحية حول كيفية إجراء التجارب في الورقة.
لتدريب نموذج محول القرار متعدد المجالات (MDDT) على ارتفاع 40 مترًا على MT40 + DMC10 مع 3 بذور على وحدة معالجة الرسومات الواحدة ، قم بتشغيل:
python main.py -m experiment_name=pretrain seed=42,43,44 env_params=multi_domain_mtdmc run_params=pretrain eval_params=pretrain_disc agent_params=cdt_pretrain_disc agent_params.kind=MDDT agent_params/model_kwargs=multi_domain_mtdmc agent_params/data_paths=mt40v2_dmc10 +agent_params/replay_buffer_kwargs=multi_domain_mtdmc +agent_params.accumulation_steps=2
لضبط النموذج الذي تم تدريبه مسبقًا باستخدام Lora على مهمة CW10 واحدة مع 3 بذور ، قم بتشغيل:
python main.py -m experiment_name=cw10_lora seed=42,43,44 env_params=mt50_pretrain run_params=finetune eval_params=finetune agent_params=cdt_mpdt_disc agent_params/model_kwargs=mdmpdt_mtdmc agent_params/data_paths=cw10_v2_cwnet_2M +agent_params/replay_buffer_kwargs=mtdmc_ft agent_params/model_kwargs/prompt_kwargs=lora env_params.envid=hammer-v2 agent_params.data_paths.names='${env_params.envid}.pkl' env_params.eval_env_names=
لضبط النموذج الذي تم تدريبه مسبقًا باستخدام L2M على جميع مهام CW10 بطريقة متتابعة مع 3 بذور ، قم بتشغيل:
python main.py -m experiment_name=cw10_cl_l2m seed=42,43,44 env_params=multi_domain_ft env_params.eval_env_names=cw10_v2 run_params=finetune_coff eval_params=finetune_md_cl agent_params=cdt_mpdt_disc +agent_params.steps_per_task=100000 agent_params/model_kwargs=mdmpdt_mtdmc agent_params/data_paths=cw10_v2_cwnet_2M +agent_params/replay_buffer_kwargs=mtdmc_ft +agent_params.replay_buffer_kwargs.kind=continual agent_params/model_kwargs/prompt_kwargs=l2m_lora
للتدريب متعدد GPU ، نستخدم torchrun . تتعارض الأداة مع hydra . لذلك ، تم إنشاء مكون إضافي للقاذفة Hydra_torchrun_launcher.
لتمكين المكون الإضافي ، استنساخ repo hydra ، القرص المضغوط contrib/hydra_torchrun_launcher ، و pip تثبيت المكون الإضافي:
git clone https://github.com/facebookresearch/hydra.git
cd hydra/contrib/hydra_torchrun_launcher
pip install -e .
يمكن استخدام البرنامج المساعد من سطر الأوامر:
python main.py -m hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 [...]
يمكن إجراء تجارب على مجموعة محلية على عقدة واحدة عبر CUDA_VISIBLE_DEVICES لتحديد وحدات معالجة الرسومات لاستخدامها:
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py -m hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 [...]
على Slurm ، تنفيذ torchrun على عقدة واحدة على حد سواء. على سبيل المثال ، لتشغيل 2 وحدات معالجة الرسومات على عقدة واحدة:
#!/bin/bash
#SBATCH --account=X
#SBATCH --qos=X
#SBATCH --partition=X
#SBATCH --nodes=1
#SBATCH --gpus=2
#SBATCH --cpus-per-task=32
source activate mddt
python main.py -m hydra/launcher=torchrun hydra.launcher.nproc_per_node=2 [...]
مثال البرامج النصية للتدريب متعدد GPU على SLURM أو PBS متوفرة في scripts .
يتطلب الركض على SLURM/PBS في إعداد متعدد العقدة رعاية أكثر بقليل. يتم توفير نصوص مثال في scripts .
إذا وجدت هذا مفيدًا ، يرجى التفكير في الاستشهاد بعملنا:
@article{schmied2024learning,
title={Learning to Modulate pre-trained Models in RL},
author={Schmied, Thomas and Hofmarcher, Markus and Paischer, Fabian and Pascanu, Razvan and Hochreiter, Sepp},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2024}
}