thunlp-aipoet / mixpoet
Source codes of MixPoet: Diverse Poetry Generation via Learning Controllable Mixed Latent Space (AAAI 2020)
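Hello! I modified this model for a text generation task with more data and more styles. The main structural change is replacing the Encoder and Decoder with Transformers; nothing else was changed. However, the model seems very hard to get to converge. Have the authors tried a Transformer-based MixPoet?
I also have a few questions about the original paper:
1. Does the latent variable z contain only factor information, or both factor and context information? If the latter, how does adversarial training ensure that the z produced by the prior learns context information? In other words, the discriminator seems able to use only the known factor labels to judge how accurate the z from the prior and the posterior is with respect to factor information.
2. The generator optimizes the posterior and the prior at the same time. Could this degrade the context information carried by the posterior? The generator can fool the discriminator perfectly as long as both z outputs carry no useful information. (I observed this during training: the context information in the posterior's z is preserved only when the posterior's gradient updates are frozen in the generator step.)
3. If z already contains enough factor and context information, why do w (the keyword) and y (the factor label) need to be concatenated with z before being fed to the Decoder? With this design it is impossible to tell whether it is z, or w and y, that actually drives the Decoder.
Looking forward to your reply ^-^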
AttributeError: 'HParams' object has no attribute 'to'. How can I fix this?
After training finished, running python generate.py -v 1 fails:
python generate.py -v 1
vocabulary size: 7024
restore checkpoint from ../checkpoint/model_ckpt_mixfine_3e.tar
loading...
load state dic, params: 141...
loading poetry filter...
rhythm dic loaded, level tone chars: 3304, oblique tone chars: 4301
rhyme dic loaded, ambiguous rhyme chars: 67
Traceback (most recent call last):
File "generate.py", line 122, in <module>
main()
File "generate.py", line 116, in main
generate_manu(args)
File "generate.py", line 38, in generate_manu
generator = Generator()
File "D:\AI\九歌\MixPoet-master\codes\generator.py", line 54, in __init__
self.tool.get_ivocab(), self.hps.data_dir)
File "D:\AI\九歌\MixPoet-master\codes\filter.py", line 43, in __init__
self.__load_rhyme_dic(data_dir+"pingshui.txt", data_dir+"pingshui_amb.pkl")
File "D:\AI\九歌\MixPoet-master\codes\filter.py", line 119, in __load_rhyme_dic
self.__char_rhyme_map = pickle.load(fin)
_pickle.UnpicklingError: invalid load key, 'v'.
The file that fails to load is pingshui_amb.pkl.
Could anyone explain the cause and how to fix it? Thanks.
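A possible cause, offered as an assumption rather than a confirmed diagnosis: pickle reports `invalid load key, 'v'` when the file begins with the byte `v`, and a Git LFS pointer file is a small text file that begins with `version https://git-lfs.github.com/spec/v1`. If pingshui_amb.pkl is stored through Git LFS and was never fetched, the file on disk would be such a pointer. The helper name `is_git_lfs_pointer` below is hypothetical and not part of MixPoet; the sketch only checks the first bytes of a file.

```python
import pickle
import tempfile
from pathlib import Path

# Prefix that Git LFS writes at the top of every pointer file.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/"


def is_git_lfs_pointer(path):
    """Return True if the file starts like a Git LFS pointer (hypothetical helper)."""
    with open(path, "rb") as fin:
        head = fin.read(len(LFS_POINTER_PREFIX))
    return head == LFS_POINTER_PREFIX


# Demo on two temporary files: a fake pointer and a real pickle.
with tempfile.TemporaryDirectory() as tmp:
    pointer = Path(tmp) / "pointer.pkl"
    pointer.write_bytes(
        b"version https://git-lfs.github.com/spec/v1\noid sha256:0000\nsize 115\n"
    )
    real = Path(tmp) / "real.pkl"
    real.write_bytes(pickle.dumps({"a": 1}))

    pointer_detected = is_git_lfs_pointer(pointer)
    real_detected = is_git_lfs_pointer(real)
```

If the check returns True for data/pingshui_amb.pkl, the actual binary still has to be downloaded (for example with `git lfs pull`) before `pickle.load` can succeed.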
python train.py
hyper-patameters:
HParams(vocab_size=7024, pad_idx=0, bos_idx=3, emb_size=256, hidden_size=512, context_size=512, latent_size=256, factor_emb_size=64, n_class1=3, n_class2=2, key_len=4, sens_num=4, sen_len=9, poem_len=30, batch_size=128, drop_ratio=0.15, weight_decay=0.00025, clip_grad_norm=2.0, max_lr=0.0008, min_lr=5e-08, warmup_steps=6000, ndis=3, min_tr=0.85, burn_down_tr=3, decay_tr=6, tau_annealing_steps=6000, min_tau=0.01, rec_warm_steps=1500, noise_decay_steps=8500, log_steps=200, sample_num=1, max_epoches=12, save_epoches=3, validate_epoches=1, fbatch_size=64, fmax_epoches=3, fsave_epoches=1, vocab_path='../corpus/vocab.pickle', ivocab_path='../corpus/ivocab.pickle', train_data='../corpus/semi_train.pickle', valid_data='../corpus/semi_valid.pickle', model_dir='../checkpoint/', data_dir='../data/', train_log_path='../log/mix_train_log.txt', valid_log_path='../log/mix_valid_log.txt', fig_log_path='../log/', corrupt_ratio=0.1, dae_epoches=10, dae_batch_size=128, dae_max_lr=0.0008, dae_min_lr=5e-08, dae_warmup_steps=4500, dae_min_tr=0.85, dae_burn_down_tr=2, dae_decay_tr=6, dae_log_steps=300, dae_validate_epoches=1, dae_save_epoches=2, dae_train_log_path='../log/dae_train_log.txt', dae_valid_log_path='../log/dae_valid_log.txt', cl_batch_size=64, cl_epoches=10, cl_max_lr=0.0008, cl_min_lr=5e-08, cl_warmup_steps=800, cl_log_steps=100, cl_validate_epoches=1, cl_save_epoches=2, cl_train_log_path='../log/cl_train_log.txt', cl_valid_log_path='../log/cl_valid_log.txt')
please check the hyper-parameters, and then press any key to continue >ok
ok
dae pretraining...
layers.embed.weight torch.Size([7024, 256])
layers.encoder.rnn.weight_ih_l0 torch.Size([1536, 256])
layers.encoder.rnn.weight_hh_l0 torch.Size([1536, 512])
layers.encoder.rnn.bias_ih_l0 torch.Size([1536])
layers.encoder.rnn.bias_hh_l0 torch.Size([1536])
layers.encoder.rnn.weight_ih_l0_reverse torch.Size([1536, 256])
layers.encoder.rnn.weight_hh_l0_reverse torch.Size([1536, 512])
layers.encoder.rnn.bias_ih_l0_reverse torch.Size([1536])
layers.encoder.rnn.bias_hh_l0_reverse torch.Size([1536])
layers.decoder.rnn.weight_ih_l0 torch.Size([1536, 512])
layers.decoder.rnn.weight_hh_l0 torch.Size([1536, 512])
layers.decoder.rnn.bias_ih_l0 torch.Size([1536])
layers.decoder.rnn.bias_hh_l0 torch.Size([1536])
layers.word_encoder.rnn.weight_ih_l0 torch.Size([256, 256])
layers.word_encoder.rnn.weight_hh_l0 torch.Size([256, 256])
layers.word_encoder.rnn.bias_ih_l0 torch.Size([256])
layers.word_encoder.rnn.bias_hh_l0 torch.Size([256])
layers.word_encoder.rnn.weight_ih_l0_reverse torch.Size([256, 256])
layers.word_encoder.rnn.weight_hh_l0_reverse torch.Size([256, 256])
layers.word_encoder.rnn.bias_ih_l0_reverse torch.Size([256])
layers.word_encoder.rnn.bias_hh_l0_reverse torch.Size([256])
layers.out_proj.weight torch.Size([7024, 512])
layers.out_proj.bias torch.Size([7024])
layers.map_x.mlp.linear_0.weight torch.Size([512, 768])
layers.map_x.mlp.linear_0.bias torch.Size([512])
layers.context.conv.weight torch.Size([512, 512, 3])
layers.context.conv.bias torch.Size([512])
layers.context.linear.weight torch.Size([512, 1024])
layers.context.linear.bias torch.Size([512])
layers.dec_init_pre.mlp.linear_0.weight torch.Size([506, 1536])
layers.dec_init_pre.mlp.linear_0.bias torch.Size([506])
params num: 31
building data for dae...
193461
34210
train batch num: 1512
valid batch num: 268
Traceback (most recent call last):
File "train.py", line 79, in <module>
main()
File "train.py", line 74, in main
pretrain(mixpoet, tool, hps)
File "train.py", line 30, in pretrain
dae_trainer.train(mixpoet, tool)
File "/content/MixPoet/codes/dae_trainer.py", line 137, in train
self.run_train(mixpoet, tool, optimizer, logger)
File "/content/MixPoet/codes/dae_trainer.py", line 83, in run_train
batch_keys, batch_poems, batch_dec_inps, batch_lengths)
File "/content/MixPoet/codes/dae_trainer.py", line 58, in run_step
mixpoet.dae_graph(keys, poems, dec_inps, lengths)
File "/content/MixPoet/codes/graphs.py", line 337, in dae_graph
_, poem_state0 = self.computer_enc(poems, self.layers['encoder'])
File "/content/MixPoet/codes/graphs.py", line 232, in computer_enc
enc_outs, enc_state = encoder(emb_inps, lengths)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/content/MixPoet/codes/layers.py", line 59, in forward
input_lens, batch_first=True, enforce_sorted=False)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/utils/rnn.py", line 249, in pack_padded_sequence
_VF._pack_padded_sequence(input, lengths, batch_first)
RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor
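The error message says that `pack_padded_sequence` wants `lengths` as a 1D int64 tensor on the CPU, while the padded input itself may stay on the GPU. A minimal sketch of one way to satisfy that is shown below. The wrapper name `pack_with_cpu_lengths` is hypothetical; in the repository the equivalent change would presumably go where layers.py calls `pack_padded_sequence`, which is an assumption based only on the traceback above.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence


def pack_with_cpu_lengths(emb_inps, input_lens):
    """Pack a padded batch, moving only the lengths tensor to the CPU.

    emb_inps: (B, L, H) padded batch, may live on any device.
    input_lens: (B,) sequence lengths, may live on any device.
    """
    # Convert to int64 and move to the CPU, as the error message requires.
    lengths = torch.as_tensor(input_lens, dtype=torch.int64).cpu()
    return pack_padded_sequence(
        emb_inps, lengths, batch_first=True, enforce_sorted=False
    )


# Tiny demo: batch of 3 sequences, max length 5, feature size 4.
emb = torch.zeros(3, 5, 4)
lens = torch.tensor([5, 2, 3])
packed = pack_with_cpu_lengths(emb, lens)
packed_rows = tuple(packed.data.shape)
```

Only the lengths are moved; the embeddings and the RNN weights stay on their original device.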
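After adapting the project into a web service, I found a memory leak
in the np.argpartition call in codes/beam.py.
Adding .copy() to its result fixes it.
An analysis I found online: https://www.nuomiphp.com/eplan/19414.html
With that change, there is no noticeable memory leak.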
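One plausible explanation, stated as an assumption because the relevant lines of beam.py are not shown here: slicing the array returned by `np.argpartition` gives a view, and a view keeps the whole underlying index array alive for as long as the slice is referenced. Calling `.copy()` on the slice produces a small array that owns its data. The function name `top_k_indices` is hypothetical and only illustrates the pattern.

```python
import numpy as np


def top_k_indices(scores, k):
    """Return the indices of the k largest scores (unordered)."""
    flat = np.asarray(scores).ravel()
    # argpartition on the negated scores puts the k largest entries first.
    partitioned = np.argpartition(-flat, k - 1)
    # The slice is a view into `partitioned`; copy() detaches it so the
    # full-size index array can be freed once this function returns.
    return partitioned[:k].copy()


idx = top_k_indices([0.1, 0.9, 0.5, 0.7], k=2)
owns_data = idx.base is None
top_two = sorted(idx.tolist())
```

The copy costs only k integers per call, which is small next to a vocabulary-sized score array.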
'ccpt_train.json' is not found.
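Hello,
While testing, loading "./data/pingshui_amb.pkl" failed
with an error saying that pingshui_amb.pkl could not be found.
I then realized it is a large file, but fetching it with git lfs gives the following error: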
Git LFS: (0 of 1 files) 0 B / 115.47 MB
batch response: This repository is over its data quota. Account responsible for LFS bandwidth should purchase more data packs to restore access.
error: failed to fetch some objects from 'https://github.com/THUNLP-AIPoet/Datasets.git/info/lfs'
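I hope you can take a look at this issue. Thanks!
I trained for 12 epochs with the settings in the original code, and the results are not very good.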
layers.py line 480: is there a mistake?
entropy = torch.log(probs+1e-10) * probs # (B, n_class)
# should it be:
entropy = - torch.log(probs+1e-10) * probs # (B, n_class)
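For reference, Shannon entropy is defined as $H(p) = -\sum_i p_i \log p_i$, so the product `log(p) * p` is non-positive on its own and needs a minus sign to become an entropy. Whether line 480 is actually a bug depends on how the term enters the loss afterwards (it may be negated or minimized later), which is not shown in this thread. The sketch below only checks the sign numerically; the tensor values are made up for illustration.

```python
import torch

# Two example distributions over n_class = 2 (illustrative values only).
probs = torch.tensor([[0.5, 0.5], [0.9, 0.1]])

# Element-wise p * log(p): every entry is <= 0 because log(p) <= 0 for p <= 1.
plogp = torch.log(probs + 1e-10) * probs  # (B, n_class)

# Entropy per row is the negated sum over classes.
entropy = -plogp.sum(dim=-1)  # (B,)

all_nonpositive = bool((plogp <= 0).all())
uniform_entropy = entropy[0].item()  # close to ln(2) for the uniform row
```

The uniform row gives the maximum entropy for two classes, and the peaked row gives a smaller value.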
Traceback (most recent call last):
File "train.py", line 80, in <module>
main()
File "train.py", line 76, in main
train(mixpoet, tool, hps)
File "train.py", line 53, in train
mix_trainer.train(mixpoet, tool)
File "C:\Users\yingf\Desktop\MixPoet\codes\mix_trainer.py", line 304, in train
File "C:\Users\yingf\Desktop\MixPoet\codes\mix_trainer.py", line 247, in run_train
logger.draw_curves()
File "C:\Users\yingf\Desktop\MixPoet\codes\logger.py", line 422, in draw_curves
fig.savefig(self.fig_path+"/latent_distance.png", dpi=300, quality=100, bbox_inches="tight")
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib\figure.py", line 3274, in savefig
self.canvas.print_figure(fname, **kwargs)
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib\backend_bases.py", line 2338, in print_figure
result = print_method(
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib\backend_bases.py", line 2204, in
print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib\_api\deprecation.py", line 385, in wrapper
arguments = signature.bind(*inner_args, **inner_kwargs).arguments
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\inspect.py", line 3037, in bind
return self._bind(args, kwargs)
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\inspect.py", line 3026, in _bind
raise TypeError(
TypeError: got an unexpected keyword argument 'quality'
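The traceback shows `savefig` rejecting the `quality` keyword passed from logger.py. Newer Matplotlib releases no longer accept `quality` in `savefig`, and the option only ever applied to JPEG output, so for a PNG it can be dropped without changing the image. A minimal sketch of a compatibility wrapper follows. The names `savefig_compat` and `_RecordingFigure` are hypothetical; the stand-in class exists only so that the example runs without Matplotlib.

```python
def savefig_compat(fig, path, **kwargs):
    """Call fig.savefig after removing the `quality` keyword (hypothetical helper)."""
    # `quality` is ignored for PNG output and rejected by newer Matplotlib.
    kwargs.pop("quality", None)
    fig.savefig(path, **kwargs)


class _RecordingFigure:
    """Stand-in for a Matplotlib figure that just records the savefig call."""

    def __init__(self):
        self.calls = []

    def savefig(self, path, **kwargs):
        self.calls.append((path, kwargs))


fake_fig = _RecordingFigure()
savefig_compat(fake_fig, "latent_distance.png", dpi=300, quality=100,
               bbox_inches="tight")
recorded_path, recorded_kwargs = fake_fig.calls[0]
```

The simpler alternative is to delete `quality=100` from the `fig.savefig(...)` call in logger.py, or to pin an older Matplotlib version that still accepts it.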