Question about Sequencer.lua (opennmt, closed)

opennmt commented on May 20, 2024
Question about Sequencer.lua

from opennmt.

Comments (4)

guillaumekln commented on May 20, 2024

The difference is that you call backward on the RVNN module because it is the one exposed by the Sequencer. As you don't override the backward function, the definition from nn.Module is used:

function Module:backward(input, gradOutput, scale)
   scale = scale or 1
   self:updateGradInput(input, gradOutput)
   self:accGradParameters(input, gradOutput, scale)
   return self.gradInput
end

which expects self.gradInput to be non-nil: it discards the return value of updateGradInput and returns self.gradInput instead.

On the other hand, the LSTM module is not directly exposed by the Sequencer, and its caller only relies on the return value of updateGradInput. See https://github.com/torch/nngraph/blob/master/gmodule.lua#L420, which is called on each node in the graph.
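The difference can be reproduced without Torch. The sketch below is a plain-Lua mock, not the real nn/nngraph code: `Module.backward` copies the body of nn.Module:backward quoted above, while `graphStep` is a hypothetical stand-in for the per-node call in gmodule.lua that keeps only the return value of updateGradInput. The `RVNN` table and its doubling gradient are made up for illustration (assuming Lua 5.1 or later).

```lua
-- Plain-Lua mock of nn.Module; no Torch required.
local Module = {}

-- Same logic as nn.Module:backward quoted above.
function Module:backward(input, gradOutput, scale)
  scale = scale or 1
  self:updateGradInput(input, gradOutput)  -- return value is discarded
  self:accGradParameters(input, gradOutput, scale)
  return self.gradInput
end

function Module:accGradParameters(input, gradOutput, scale)
  -- This mock has no parameters.
end

-- Hypothetical RVNN-like module: it returns the gradient
-- but never stores it in self.gradInput.
local RVNN = setmetatable({}, { __index = Module })
RVNN.__index = RVNN

function RVNN.new()
  return setmetatable({}, RVNN)
end

function RVNN:updateGradInput(input, gradOutput)
  return gradOutput * 2  -- made-up gradient
end

-- Hypothetical stand-in for a graph node update:
-- only the return value of updateGradInput is used.
local function graphStep(module, input, gradOutput)
  return module:updateGradInput(input, gradOutput)
end

local m = RVNN.new()
print(graphStep(m, 1, 3))  -- 6: the graph-style caller gets the gradient
print(m:backward(1, 3))    -- nil: Module:backward returns self.gradInput
```

This is why the same module can work inside a graph and still fail when backward is called on it directly.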


However, these lines:

self.gradInput = self.net:updateGradInput(input, gradOutput)
return self.gradInput

should also appear in the LSTM module for consistency.
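A minimal sketch of the suggested pattern, again as a plain-Lua mock rather than real Torch code: `Module` stands in for nn.Module, and `Wrapper` and `net` are hypothetical stand-ins for a module that delegates to self.net. Because updateGradInput both stores and returns the gradient, the direct backward call and a graph-style call see the same value.

```lua
-- Plain-Lua mock of nn.Module; no Torch required.
local Module = {}

-- Same logic as nn.Module:backward quoted above.
function Module:backward(input, gradOutput, scale)
  scale = scale or 1
  self:updateGradInput(input, gradOutput)
  self:accGradParameters(input, gradOutput, scale)
  return self.gradInput
end

function Module:accGradParameters(input, gradOutput, scale)
  -- This mock has no parameters.
end

-- Hypothetical wrapper that delegates to an inner network.
local Wrapper = setmetatable({}, { __index = Module })
Wrapper.__index = Wrapper

function Wrapper.new(net)
  return setmetatable({ net = net }, Wrapper)
end

function Wrapper:updateGradInput(input, gradOutput)
  -- Store the gradient AND return it, as suggested above.
  self.gradInput = self.net:updateGradInput(input, gradOutput)
  return self.gradInput
end

-- Hypothetical inner network with a made-up gradient.
local net = {}
function net:updateGradInput(input, gradOutput)
  return gradOutput * 2
end

local w = Wrapper.new(net)
print(w:updateGradInput(1, 3))  -- 6 (graph-style caller)
print(w:backward(1, 3))         -- 6 (Module:backward caller)
```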

So thank you for your question.

guillaumekln commented on May 20, 2024

How do you initialize the Sequencer?

helson73 commented on May 20, 2024

@guillaumekln
I initialize the Sequencer like this:

local Tree, parent = torch.class('onmt.Tree', 'onmt.Sequencer')

function Tree:__init (rvnn)
  self.rvnn = rvnn
  parent.__init(self, self.rvnn)
  self:resetPreallocation()
end

function Tree.load(pretrained)
  -- torch.factory returns a constructor; call it to get an empty instance
  local self = torch.factory('onmt.Tree')()
  self.rvnn = pretrained.modules[1]
  parent.__init(self, self.rvnn)
  self:resetPreallocation()
  return self
end

function Tree:training()
  parent.training(self)
end

function Tree:evaluate()
  parent.evaluate(self)
end

function Tree:serialize()
  return {
    modules = self.modules
  }
end

function Tree:maskPadding()
  self.maskPad = true
end

function Tree:resetPreallocation()
  self.headProto = torch.Tensor()
  self.depProto = torch.Tensor()
  self.gradFeedProto = torch.Tensor()
end

function Tree:forward(batch, f2s_)
  if self.train then
    self.inputs = {}
    self:_reset_noise()
  end

  local head_ = onmt.utils.Tensor.reuseTensor(self.headProto,
                                              {batch.size, self.rvnn.outSize})
  local dep_ = onmt.utils.Tensor.reuseTensor(self.depProto,
                                              {batch.size, self.rvnn.outSize})

  for t = 1, batch.headLength do
    onmt.utils.DepTree._get(head_, f2s_, batch.head[t])
    onmt.utils.DepTree._get(dep_, f2s_, batch.dep[t])
    local tree_input = {head_, dep_, batch.relation[t]}
    if self.train then
      self.inputs[t] = tree_input
    end
    onmt.utils.DepTree._set(f2s_, self:net(t):forward(tree_input), batch.update[t])
  end
  return f2s_
end

function Tree:backward(batch, gradFeedOutput)
  local gradFeed_ = onmt.utils.Tensor.reuseTensor(self.gradFeedProto,
                                                  {batch.size, self.rvnn.outSize})
  for t = batch.headLength, 1, -1 do
    onmt.utils.DepTree._get(gradFeed_, gradFeedOutput, batch.update[t])
    -- backward returns self.gradInput, so the RVNN's updateGradInput must assign it
    local dtree = self:net(t):backward(self.inputs[t], gradFeed_)
    onmt.utils.DepTree._add(gradFeedOutput, dtree[1], batch.head[t])
    onmt.utils.DepTree._add(gradFeedOutput, dtree[2], batch.dep[t])
    onmt.utils.DepTree._fill(gradFeedOutput, 0, batch.update[t])
  end
  return gradFeedOutput
end

helson73 commented on May 20, 2024

@guillaumekln That's really helpful. Thanks!
