relational networks
1.0.0
关系网络的Pytorch实施 - 一个简单的神经网络模块用于关系推理
实施并在类似的CLEVR任务上进行了测试。
Clevr的简化版本。该版本由每个图像由10000张图像和20个问题(10个关系问题和10个非关系问题)组成。将6种颜色(红色,绿色,蓝色,橙色,灰色,黄色)分配为随机选择的形状(正方形或圆形),并将其放置在图像中。
非关系问题由3个亚型组成:
这些问题是“非关系”,因为代理只需要专注于某些对象。
关系问题由3个亚型组成:
这些问题是“关系”,因为代理必须考虑对象之间的关系。
问题被编码为一种尺寸为11:6的矢量,对于6种颜色的某些颜色的一速矢量,2对于关系/非相关问题的一速矢量为2。 3对于3个亚型的一个速率向量。

即,在显示的示例图像中,我们可以产生非关系问题,例如:
和关系问题:
通过environment.yml创建conda环境。
$ conda env create -f environment.yml
激活环境
$ conda activate RN3
如果您不使用conda安装python 3并使用pip install来安装剩余的依赖项。依赖项列表可以在environment.yml中找到。
$ ./run.sh
或者
$ python sort_of_clevr_generator.py
生成clevr数据集和
$ python main.py
训练二进制RN模型。或者,使用
$ python main.py --relation-type=ternary
训练三元RN模型。
在原始论文中,类似的clevr任务使用了CLEVR任务的不同模型。但是,由于使用CLEVR的模型所需的计算时间要少得多(网络要小得多),因此该模型用于CLEVR任务。
| 关系网络(第20个时代) | CNN + MLP(无RN,第100个时期) | |
|---|---|---|
| 非关系问题 | 99% | 66% |
| 关系问题 | 89% | 66% |
CNN + MLP对培训数据过于拟合。
关系网络在关系问题和非关系问题中显示出更好的结果。
@gngdb将模型加快了10次。