
mixpoet's People

Contributors

mtmoon · xiaoyuanyi


mixpoet's Issues

Model convergence issue

Hello, author! I adapted this model for a text-generation task with more data and more styles. The main structural change was replacing the Encoder and Decoder with Transformer architectures; everything else was left unchanged. However, the model seems very hard to converge. Have you ever tried a Transformer-based MixPoet?

In addition, I have a few questions about the original paper:
1. Does the latent variable z carry only the factor information, or both the factor and the context information? If the latter, how does adversarial training ensure that the prior's z learns the context information? In other words, the discriminator seems able to judge the z produced by the prior and the posterior only on the accuracy of the factor information, via the known factor labels.
2. Since the generator optimizes both the posterior and the prior, could the context information carried by the posterior degrade? As long as both output z's carry no useful information, the discriminator can be fooled perfectly. (I observed this during training: only by freezing the gradient updates to the generator's posterior could the context information in the posterior's z be preserved.)
3. If z already contains enough factor and context information, why are w (key_word) and y (factor label) concatenated with z before being fed to the decoder? Doing so makes it impossible to tell whether it is z or the keyword and label that actually drives the decoder.

Looking forward to your reply ^-^

bug

AttributeError: 'HParams' object has no attribute 'to'. How can this be resolved?
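
A minimal sketch of one likely cause, assuming HParams is the plain namedtuple of settings printed by train.py (see the hyper-parameter dump further below): .to() is a method of tensors and nn.Modules, so calling it on the settings object raises exactly this error. The fix would be to call .to(device) on the model or its tensors, not on the hyper-parameter object.

import collections

# Hypothetical reproduction: HParams is a plain namedtuple, not a tensor or nn.Module,
# so it has no .to(device) method.
HParams = collections.namedtuple("HParams", ["batch_size", "hidden_size"])
hps = HParams(batch_size=128, hidden_size=512)
# hps.to("cuda")  # would raise: AttributeError: 'HParams' object has no attribute 'to'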

Pickle loading error: _pickle.UnpicklingError: invalid load key, 'v'

After training finishes, running python generate.py -v 1 fails with:
python generate.py -v 1
vocabulary size: 7024
restore checkpoint from ../checkpoint/model_ckpt_mixfine_3e.tar
loading...
load state dic, params: 141...
loading poetry filter...
rhythm dic loaded, level tone chars: 3304, oblique tone chars: 4301
rhyme dic loaded, ambiguous rhyme chars: 67
Traceback (most recent call last):
File "generate.py", line 122, in
main()
File "generate.py", line 116, in main
generate_manu(args)
File "generate.py", line 38, in generate_manu
generator = Generator()
File "D:\AI\九歌\MixPoet-master\codes\generator.py", line 54, in init
self.tool.get_ivocab(), self.hps.data_dir)
File "D:\AI\九歌\MixPoet-master\codes\filter.py", line 43, in init
self.__load_rhyme_dic(data_dir+"pingshui.txt", data_dir+"pingshui_amb.pkl")
File "D:\AI\九歌\MixPoet-master\codes\filter.py", line 119, in __load_rhyme_dic
self.__char_rhyme_map = pickle.load(fin)
_pickle.UnpicklingError: invalid load key, 'v'.
The failing file is pingshui_amb.pkl.
Could anyone explain what causes this and how to fix it? Thanks.
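
A common cause of "invalid load key, 'v'" is that the .pkl on disk is actually a Git LFS pointer file (plain text beginning with "version https://git-lfs...") rather than the real binary data. A quick check, assuming the default data path:

# Inspect the first bytes of the file: a real pickle typically starts with b'\x80',
# while an LFS pointer is plain text starting with b'version https://git-lfs'.
with open("../data/pingshui_amb.pkl", "rb") as fin:
    head = fin.read(64)
print(head)

If it turns out to be a pointer, the real file needs to be fetched (e.g. with git lfs pull, or downloaded manually) before generate.py can load it.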

RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor

python train.py

hyper-patameters:
HParams(vocab_size=7024, pad_idx=0, bos_idx=3, emb_size=256, hidden_size=512, context_size=512, latent_size=256, factor_emb_size=64, n_class1=3, n_class2=2, key_len=4, sens_num=4, sen_len=9, poem_len=30, batch_size=128, drop_ratio=0.15, weight_decay=0.00025, clip_grad_norm=2.0, max_lr=0.0008, min_lr=5e-08, warmup_steps=6000, ndis=3, min_tr=0.85, burn_down_tr=3, decay_tr=6, tau_annealing_steps=6000, min_tau=0.01, rec_warm_steps=1500, noise_decay_steps=8500, log_steps=200, sample_num=1, max_epoches=12, save_epoches=3, validate_epoches=1, fbatch_size=64, fmax_epoches=3, fsave_epoches=1, vocab_path='../corpus/vocab.pickle', ivocab_path='../corpus/ivocab.pickle', train_data='../corpus/semi_train.pickle', valid_data='../corpus/semi_valid.pickle', model_dir='../checkpoint/', data_dir='../data/', train_log_path='../log/mix_train_log.txt', valid_log_path='../log/mix_valid_log.txt', fig_log_path='../log/', corrupt_ratio=0.1, dae_epoches=10, dae_batch_size=128, dae_max_lr=0.0008, dae_min_lr=5e-08, dae_warmup_steps=4500, dae_min_tr=0.85, dae_burn_down_tr=2, dae_decay_tr=6, dae_log_steps=300, dae_validate_epoches=1, dae_save_epoches=2, dae_train_log_path='../log/dae_train_log.txt', dae_valid_log_path='../log/dae_valid_log.txt', cl_batch_size=64, cl_epoches=10, cl_max_lr=0.0008, cl_min_lr=5e-08, cl_warmup_steps=800, cl_log_steps=100, cl_validate_epoches=1, cl_save_epoches=2, cl_train_log_path='../log/cl_train_log.txt', cl_valid_log_path='../log/cl_valid_log.txt')
please check the hyper-parameters, and then press any key to continue >ok
ok
dae pretraining...
layers.embed.weight torch.Size([7024, 256])
layers.encoder.rnn.weight_ih_l0 torch.Size([1536, 256])
layers.encoder.rnn.weight_hh_l0 torch.Size([1536, 512])
layers.encoder.rnn.bias_ih_l0 torch.Size([1536])
layers.encoder.rnn.bias_hh_l0 torch.Size([1536])
layers.encoder.rnn.weight_ih_l0_reverse torch.Size([1536, 256])
layers.encoder.rnn.weight_hh_l0_reverse torch.Size([1536, 512])
layers.encoder.rnn.bias_ih_l0_reverse torch.Size([1536])
layers.encoder.rnn.bias_hh_l0_reverse torch.Size([1536])
layers.decoder.rnn.weight_ih_l0 torch.Size([1536, 512])
layers.decoder.rnn.weight_hh_l0 torch.Size([1536, 512])
layers.decoder.rnn.bias_ih_l0 torch.Size([1536])
layers.decoder.rnn.bias_hh_l0 torch.Size([1536])
layers.word_encoder.rnn.weight_ih_l0 torch.Size([256, 256])
layers.word_encoder.rnn.weight_hh_l0 torch.Size([256, 256])
layers.word_encoder.rnn.bias_ih_l0 torch.Size([256])
layers.word_encoder.rnn.bias_hh_l0 torch.Size([256])
layers.word_encoder.rnn.weight_ih_l0_reverse torch.Size([256, 256])
layers.word_encoder.rnn.weight_hh_l0_reverse torch.Size([256, 256])
layers.word_encoder.rnn.bias_ih_l0_reverse torch.Size([256])
layers.word_encoder.rnn.bias_hh_l0_reverse torch.Size([256])
layers.out_proj.weight torch.Size([7024, 512])
layers.out_proj.bias torch.Size([7024])
layers.map_x.mlp.linear_0.weight torch.Size([512, 768])
layers.map_x.mlp.linear_0.bias torch.Size([512])
layers.context.conv.weight torch.Size([512, 512, 3])
layers.context.conv.bias torch.Size([512])
layers.context.linear.weight torch.Size([512, 1024])
layers.context.linear.bias torch.Size([512])
layers.dec_init_pre.mlp.linear_0.weight torch.Size([506, 1536])
layers.dec_init_pre.mlp.linear_0.bias torch.Size([506])
params num: 31
building data for dae...
193461
34210
train batch num: 1512
valid batch num: 268
Traceback (most recent call last):
File "train.py", line 79, in
main()
File "train.py", line 74, in main
pretrain(mixpoet, tool, hps)
File "train.py", line 30, in pretrain
dae_trainer.train(mixpoet, tool)
File "/content/MixPoet/codes/dae_trainer.py", line 137, in train
self.run_train(mixpoet, tool, optimizer, logger)
File "/content/MixPoet/codes/dae_trainer.py", line 83, in run_train
batch_keys, batch_poems, batch_dec_inps, batch_lengths)
File "/content/MixPoet/codes/dae_trainer.py", line 58, in run_step
mixpoet.dae_graph(keys, poems, dec_inps, lengths)
File "/content/MixPoet/codes/graphs.py", line 337, in dae_graph
_, poem_state0 = self.computer_enc(poems, self.layers['encoder'])
File "/content/MixPoet/codes/graphs.py", line 232, in computer_enc
enc_outs, enc_state = encoder(emb_inps, lengths)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/content/MixPoet/codes/layers.py", line 59, in forward
input_lens, batch_first=True, enforce_sorted=False)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/utils/rnn.py", line 249, in pack_padded_sequence
_VF._pack_padded_sequence(input, lengths, batch_first)
RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor
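
Since PyTorch 1.7, pack_padded_sequence requires the lengths argument to be a CPU int64 tensor even when the inputs are on the GPU. A hedged fix, assuming the pack call in layers.py (line 59) looks like the traceback suggests, is to move the lengths to the CPU before packing:

# layers.py, in the encoder's forward(); 'embed_inps' is a placeholder name for the
# first argument of the original call (not visible in the traceback).
# Recent PyTorch only accepts CPU int64 lengths here, so move input_lens to the CPU.
packed = torch.nn.utils.rnn.pack_padded_sequence(
    embed_inps, input_lens.cpu(), batch_first=True, enforce_sorted=False)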

Error About "pingshui_amb.pkl Not Found"

Hello,
When testing, loading "./data/pingshui_amb.pkl" fails with an error
saying that the file pingshui_amb.pkl cannot be found.

I later found that this is a large file; after using git lfs, the following error appears:
Git LFS: (0 of 1 files) 0 B / 115.47 MB
batch response: This repository is over its data quota. Account responsible for LFS bandwidth should purchase more data packs to restore access.
error: failed to fetch some objects from 'https://github.com/THUNLP-AIPoet/Datasets.git/info/lfs'

I hope you can take a look at this issue. Thanks!

entropy loss

layers.py, line 480: is there a mistake here?

entropy = torch.log(probs+1e-10) * probs # (B, n_class)
# should it be the following instead?
entropy = - torch.log(probs+1e-10) * probs # (B, n_class)
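
For reference, the Shannon entropy of a categorical distribution is H(p) = -Σ p·log p, so the expression without the minus sign computes the negative entropy; whether that is a bug depends on whether the loss is meant to maximize or minimize entropy. A small sketch of the two sign conventions (hypothetical tensor, not the repo's actual loss):

import torch

probs = torch.tensor([[0.7, 0.2, 0.1]])                          # (B, n_class), rows sum to 1
neg_entropy = (torch.log(probs + 1e-10) * probs).sum(dim=-1)     # <= 0, what line 480 computes
entropy = -(torch.log(probs + 1e-10) * probs).sum(dim=-1)        # >= 0, Shannon entropy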

Error about "compute_prior"

Thank you very much to your team for this work!
An error occurred when I ran training:
[screenshot of the error]
I checked the graph.py file and could not find a definition of the compute_prior() function. I hope you can clarify this when convenient. Thanks!

TypeError: got an unexpected keyword argument 'quality'

Traceback (most recent call last):
File "train.py", line 80, in
main()
File "train.py", line 76, in main
train(mixpoet, tool, hps)
File "train.py", line 53, in train
mix_trainer.train(mixpoet, tool)
File "C:\Users\yingf\Desktop\MixPoet\codes\mix_trainer.py", line 304, in train
File "C:\Users\yingf\Desktop\MixPoet\codes\mix_trainer.py", line 247, in run_train
logger.draw_curves()
File "C:\Users\yingf\Desktop\MixPoet\codes\logger.py", line 422, in draw_curves
fig.savefig(self.fig_path+"/latent_distance.png", dpi=300, quality=100, bbox_inches="tight")
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib\figure.py", line 3274, in savefig
self.canvas.print_figure(fname, **kwargs)
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib\backend_bases.py", line 2338, in print_figure
result = print_method(
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib\backend_bases.py", line 2204, in
print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\site-packages\matplotlib_api\deprecation.py", line 385, in wrapper
arguments = signature.bind(*inner_args, **inner_kwargs).arguments
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\inspect.py", line 3037, in bind
return self._bind(args, kwargs)
File "C:\Users\yingf\anaconda3\envs\MixPoet\lib\inspect.py", line 3026, in _bind
raise TypeError(
TypeError: got an unexpected keyword argument 'quality'
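
The quality keyword of savefig was deprecated in Matplotlib 3.3 and removed in a later release (it only ever applied to JPEG output, so it has no effect on a PNG anyway). A likely fix, assuming logger.py line 422 as shown in the traceback, is simply to drop that argument; pinning an older Matplotlib that still accepts the keyword would also work.

# logger.py, draw_curves(): same call without the removed 'quality' keyword.
fig.savefig(self.fig_path + "/latent_distance.png", dpi=300, bbox_inches="tight")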
