
Comments (8)

yysirs commented on August 22, 2024

if args.status == 'train':
    model = nn.DataParallel(model, device_ids=[0, 1])
    device = [0, 1] if torch.cuda.device_count() > 1 else [0]
    trainer = Trainer(datasets['train'], model, optimizer, loss, args.batch,
                      n_epochs=args.epoch, dev_data=datasets['dev'],
                      metrics=metrics, callbacks=callbacks,
                      dev_batch_size=args.test_batch, test_use_tqdm=False,
                      check_code_level=-1, update_every=args.update_every,
                      save_path="./model")
    trainer.train()
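In this snippet, nn.DataParallel gets device_ids=[0, 1] unconditionally, and the `device` list is computed only afterwards and never used. Below is a minimal sketch that picks the ids first, assuming at most two GPUs are wanted. `choose_device_ids` is a hypothetical helper, and the PyTorch calls appear only in comments so the block runs without a GPU.

```python
from typing import List


def choose_device_ids(device_count: int, max_gpus: int = 2) -> List[int]:
    """Pick GPU ids for nn.DataParallel from the number of visible devices.

    Hypothetical helper: decide the ids *before* wrapping the model, so a
    single-GPU machine never gets device_ids=[0, 1]. Returns [] with no GPUs.
    """
    return list(range(min(device_count, max_gpus)))


# Intended use (requires PyTorch; not executed here):
#   ids = choose_device_ids(torch.cuda.device_count())
#   if len(ids) > 1:
#       model = nn.DataParallel(model, device_ids=ids)
```

With one visible GPU the model is left unwrapped, which matches the single-GPU setup the reporter says already works.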

from flat-lattice-transformer.

yysirs commented on August 22, 2024

The code above fails with: RuntimeError: The size of tensor a (94) must match the size of tensor b (214) at non-singleton dimension 1
The code itself is fine, of course: when training on other corpora it runs on a single GPU.

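mirrorQAQ commented on August 22, 2024

I ran into the same problem. Passing max_seq_len as the second argument to seq_len_to_mask fixes it. If you print the shapes, you will see that the mask has the character-sequence length, while the tensor it is applied to has the max length. The max_len parameter of seq_len_to_mask defaults to None; pass it explicitly and the error goes away.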
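The fix above can be illustrated with a small pure-Python sketch. It rests on two assumptions. First, fastNLP's seq_len_to_mask(seq_len, max_len=None) sizes the mask by the largest length in the batch it receives when max_len is None. Second, nn.DataParallel splits each batch along dim 0 into contiguous chunks. `seq_len_to_mask`, `scatter` and `replica_mask_widths` below are hypothetical stand-ins, not the fastNLP or PyTorch code, and the padded width is simplified to the largest length in the batch.

Under these assumptions, a replica that receives only short sentences builds a mask narrower than the padded tensor it is applied to. That would be consistent with the mismatch appearing only with multiple GPUs.

```python
from typing import List, Optional


def seq_len_to_mask(seq_len: List[int], max_len: Optional[int] = None) -> List[List[bool]]:
    """Pure-Python stand-in for fastNLP's seq_len_to_mask (hypothetical)."""
    # With max_len=None the width comes from *this* batch, not the padded tensor.
    width = max(seq_len) if max_len is None else max_len
    return [[pos < n for pos in range(width)] for n in seq_len]


def scatter(seq_len: List[int], n_devices: int) -> List[List[int]]:
    """Split a batch along dim 0 into contiguous chunks, roughly like DataParallel."""
    chunk = -(-len(seq_len) // n_devices)  # ceiling division
    return [seq_len[i:i + chunk] for i in range(0, len(seq_len), chunk)]


def replica_mask_widths(batch_seq_len: List[int], n_devices: int, pass_max_len: bool) -> List[int]:
    """Mask width each replica would build; the padded input width is max(batch_seq_len)."""
    padded_width = max(batch_seq_len)
    widths = []
    for lens in scatter(batch_seq_len, n_devices):
        mask = seq_len_to_mask(lens, padded_width if pass_max_len else None)
        widths.append(len(mask[0]))
    return widths


# A batch padded to width 8, sorted by length, split across 2 GPUs.
print(replica_mask_widths([8, 7, 3, 2], 2, pass_max_len=False))  # [8, 3]: replica 1 is too narrow
print(replica_mask_widths([8, 7, 3, 2], 2, pass_max_len=True))   # [8, 8]: matches the padded width
```

In the real model the fix is the one described in the comment: pass the padded length (max_seq_len in FLAT) as the second argument to seq_len_to_mask.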

currylym commented on August 22, 2024

@yysirs @crazymirror Did you get multi-GPU training working? After changing seq_len_to_mask, is anything else needed? Thanks 🙏

currylym commented on August 22, 2024

Error message:

Traceback (most recent call last):                                                                                                                                                
  File "flat_main.py", line 806, in <module>
    trainer.train()
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/core/trainer.py", line 613, in train
    self.callback_manager.on_exception(e)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/core/callback.py", line 309, in wrapper
    returns.append(getattr(callback, func.__name__)(*arg))
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/core/callback.py", line 505, in on_exception
    raise exception  # 抛出陌生Error
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/core/trainer.py", line 609, in train
    self._train()
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/core/trainer.py", line 664, in _train
    prediction = self._data_forward(self.model, batch_x)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/core/trainer.py", line 752, in _data_forward
    y = network(**x)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
StopIteration: Caught StopIteration in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "../V1/models.py", line 438, in forward
    bert_embed = self.bert_embedding(char_for_bert)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "../fastNLP_module.py", line 387, in forward
    outputs = self.model(words)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/embeddings/bert_embedding.py", line 370, in forward
    output_all_encoded_layers=True)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data00/home/luyiming.ez4curry/.local/lib/python3.6/site-packages/fastNLP/modules/encoder/bert.py", line 503, in forward
    extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
StopIteration
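This StopIteration is raised by `next(self.parameters())` in fastNLP's bert.py inside replica 1. A commonly reported cause, which is an assumption here and not confirmed in this thread, is that since PyTorch 1.5 the module replicas created by nn.DataParallel no longer return their weights from `parameters()`, so the iterator is empty. Below is a minimal pure-Python sketch of a guard. `first_param_dtype` is a hypothetical helper, and SimpleNamespace objects stand in for real parameters.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def first_param_dtype(params: Iterable[Any], default: Any) -> Any:
    """Return the dtype of the first parameter, or `default` if there are none.

    Hypothetical replacement for `next(self.parameters()).dtype`, which raises
    StopIteration when the parameter iterator is empty.
    """
    for param in params:
        return param.dtype
    return default


# Stand-ins for parameters: only the .dtype attribute is used.
print(first_param_dtype([SimpleNamespace(dtype="float16")], "float32"))  # float16
print(first_param_dtype(iter([]), "float32"))                            # float32

# Possible use in a patched bert.py (not executed here; assumes fp32 training):
#   dtype = first_param_dtype(self.parameters(), torch.float32)
#   extended_attention_mask = extended_attention_mask.to(dtype=dtype)
```

The float32 fallback is only right when training in fp32. Two alternatives avoid patching the library, though neither is verified here: an older PyTorch release, or DistributedDataParallel, which runs a full model copy per process.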

HuStanding commented on August 22, 2024

Quoting @currylym: "@yysirs @crazymirror Did you get multi-GPU training working? After changing seq_len_to_mask, is anything else needed? Thanks 🙏"

Hi, did you solve this problem?

yysirs commented on August 22, 2024

[two screenshots, not preserved]

yysirs commented on August 22, 2024

@HuStanding @currylym Just change it here (see the screenshots above). The other errors are probably caused by the fix not being applied correctly.
