seqGAN
1.0.0
“ seqgan:序列生成的对抗网的Pytorch实施,具有策略梯度”。 (Yu,Lantao等)。该代码被高度简化,评论和(希望)直接理解。实施的策略梯度也比原始工作(https://github.com/lantaoyu/seqgan/)简单得多,并且不涉及整个句子的单一奖励(灵感来自http://karpathy.github.io.github.io/2016/05/05/31/rl/)。
所使用的体系结构与Orignal工作中的架构不同。具体而言,将复发性双向GRU网络用作歧视者。
如本文所述,该代码对合成数据执行实验。
鼓励您提出对代码工作作为问题的任何疑问。
运行代码:
python main.pyMain.py应该是您进入代码的切入点。
在这种情况下,以下hacks(从https://github.com/soumith/ganhacks借用)似乎有效:
训练判别器比发电机要多得多(生成器仅针对一批示例进行训练,并且增加批量尺寸会损害稳定性)
使用Adam进行发电机和Adagrad进行歧视者
在GAN阶段调整发电机的学习率
在训练和测试阶段使用辍学
稳定性几乎对每个参数非常敏感:/
GAN阶段可能并不总是会导致NLL大量下降(有时很少) - 我怀疑这是由于实施的政策梯度的非常粗略的性质(没有推出)。
MLE训练后获得100个时期的学习曲线,然后进行对抗训练。 (您的结果可能会有所不同!)