TextGan是基于生成对抗网络(GAN)的文本生成模型的Pytorch框架,包括一般文本生成模型和类别文本生成模型。 TextGan是一个基准平台,以支持基于GAN的文本生成模型的研究。由于大多数基于GAN的文本生成模型都是由TensorFlow实现的,因此TextGan可以帮助那些习惯于Pytorch更快地输入文本生成字段的人。
如果您在实施中发现任何错误,请告诉我!另外,如果您想添加其他型号,请随时为此存储库做出贡献。
要安装,请运行pip install -r requirements.txt 。如果出现CUDA问题,请咨询官方Pytorch入门指南。
下载稳定版本和UNZIP:http://kheaffield.com/code/kenlm.tar.gz
需要提升> = 1.42.0和BJAM
sudo apt-get install libboost-all-devbrew install boost; brew install bjam在Kenlm目录中运行:
mkdir -p build
cd build
cmake ..
make -j 4 pip install https://github.com/kpu/kenlm/archive/master.zip
有关KENLM的更多信息,请参见:https://github.com/kpu/kenlm和http://kheaffield.com/code/kenlm/
git clone https://github.com/williamSYSU/TextGAN-PyTorch.git
cd TextGAN-PyTorchImage COCO , EMNLP NEWs , Movie Review , Amazon Review )。 cd run
python3 run_[model_name].py 0 0 # The first 0 is job_id, the second 0 is gpu_id
# For example
python3 run_seqgan.py 0 0讲师
对于每个模型,整个运行过程均在instructor/oracle_data/seqgan_instructor.py中定义。 (例如,在合成数据实验中以Seqgan为例)。 init_model()和optimize()之类的一些基本功能是在instructor.py中的基类BasicInstructor中定义的。如果要添加新的基于GAN的文本生成模型,请在instructor/oracle_data下创建新的讲师,并为模型定义培训过程。
可视化
使用utils/visualization.py可视化日志文件,包括模型丢失和指标得分。在log_file_list中自定义日志文件,不超过len(color_list) 。日志文件名应排除.txt 。
记录
TextGan-Pytorch使用Python中的logging模块来记录运行过程,例如发电机的损失和度量分数。为了方便可视化,将分别保存两个相同的日志文件log/log_****_****.txt和save/**/log.txt 。此外,该代码将自动保存模型的状态和在./save/**/models和./save/models and ./save/**/samples中的批处理大小,每个日志步骤,其中**取决于您的hyper-parameters。
运行信号
您可以根据字典文件run_signal.txt轻松地使用类Signal (请参阅utils/helpers.py )来控制训练过程。
对于使用Signal ,只需编辑本地文件run_signal.txt ,然后将pre_sig设置为Fasle ,该程序将停止训练过程并逐步进入下一个训练阶段。如果您认为当前的培训足够,则很方便地停止培训。
自动选择GPU
在config.py中,该程序将自动在nvidia-smi中选择具有最小GPU-Util的GPU设备。默认情况下启用了此功能。如果要手动选择GPU设备,请在run_[run_model].py中取消点击--device args,并用命令指定GPU设备。
运行文件:run_seqgan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(来自Seqgan)

运行文件:run_leakgan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(来自Leakgan)

运行文件:run_maligan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(从我的理解中)

运行文件:run_jsdgan.py
讲师:oracle_data,real_data
模型:生成器(无歧视器)
结构(从我的理解中)

运行文件:run_relgan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(从我的理解中)

运行文件:run_dpgan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(来自DPGAN)

运行文件:run_dgsan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
运行文件:run_cot.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(来自COT)

运行文件:run_sentigan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(来自Sentigan)

运行文件:run_catgan.py
讲师:oracle_data,real_data
模型:生成器,歧视器
结构(来自Catgan)

麻省理工学院Lincense