sgan
1.0.0
这是纸的代码
社会甘:具有生成对抗网络的社会可接受的轨迹
Agrim Gupta,Justin Johnson,Fei-Fei Li,Silvio Savarese,Alexandre Alahi
在CVPR 2018上发表
人类运动是人际关系,多模式,并遵循社会惯例。在本文中,我们通过结合序列预测和生成对抗网络的工具来解决此问题:一个经常性的序列到序列模型观察运动历史并预测未来的行为,并使用一种新颖的合并机制来汇总人之间的信息。
下面我们显示了在复杂场景中我们的模型做出的社会可接受预测的示例。每个人都用不同的颜色表示。我们表示通过点观察到的轨迹,并通过恒星预测轨迹。


如果您发现此代码对您的研究有用,请引用
@inproceedings{gupta2018social,
title={Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks},
author={Gupta, Agrim and Johnson, Justin and Fei-Fei, Li and Savarese, Silvio and Alahi, Alexandre},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
number={CONF},
year={2018}
}
我们的模型由三个关键组件组成:生成器(G),池模块(PM)和歧视器(D)。 G基于编码器框架,在该框架中,我们通过PM链接编码器和解码器的隐藏状态。 G将所有涉及场景的人的输入轨迹和输出相应的预测轨迹作为输入轨迹。 d输入整个序列包括输入轨迹和未来预测,并将它们归类为“真实/假货”。

所有代码均在Ubuntu 16.04上使用Python 3.5和Pytorch 0.4开发和测试。
您可以设置虚拟环境以运行这样的代码:
python3 -m venv env # Create a virtual environment
source env/bin/activate # Activate virtual environment
pip install -r requirements.txt # Install dependencies
echo $PWD > env/lib/python3.5/site-packages/sgan.pth # Add current directory to python path
# Work for a while ...
deactivate # Exit virtual environment 您可以通过运行脚本bash scripts/download_models.sh下载验证的模型。这将下载以下模型:
sgan-models/<dataset_name>_<pred_len>.pt :包含所有五个数据集的10个审慎模型。这些模型对应于表1中的SGAN-20V-20。sgan-p-models/<dataset_name>_<pred_len>.pt :包含所有五个数据集的10个预读模型。这些模型对应于表1中的SGAN-20VP-20。请参阅模型动物园以获取结果。
您可以使用脚本scripts/evaluate_model.py来轻松运行任何数据集上的任何预处理模型。例如,您可以为SGAN-20V-20的所有数据集复制表1结果:
python scripts/evaluate_model.py
--model_path models/sgan-models可以在此处找到培训新模型的说明。