Problem of backward of APN (racnn-pytorch, open issue)

jeong-tae commented on August 22, 2024
from racnn-pytorch.

Comments (1)

YiKeYaTu commented on August 22, 2024

Thanks for sharing the code. It helped me understand the APN; I had been confused about how the author crops the attention region.

In the backward code of the APN, I found that you use a fixed value for in_size (if I understand the code correctly). Do you just backpropagate the gradient to a fixed location? If it is fixed, why did you do that? If not, how do you backpropagate the gradient to the attention location?

Thanks in advance

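def backward(self, grad_output):
    # Tensors saved in forward(); they are not used in the computation below.
    images, ret_tensor = self.saved_variables[0], self.saved_variables[1]
    in_size = 224  # must match the spatial size of grad_output
    ret = torch.Tensor(grad_output.size(0), 3).zero_()
    # Negative squared gradient, summed over channels: shape (batch, H, W).
    norm = -(grad_output * grad_output).sum(dim=1)

    # x[i][j] = i (row index), y[i][j] = j (column index).
    x = torch.stack([torch.arange(0, in_size)] * in_size).t()
    y = x.t()
    long_size = (in_size/3*2)
    short_size = (in_size/3)
    # +1 in the last third, -1 in the first third, 0 in the middle third.
    mx = (x >= long_size).float() - (x < short_size).float()
    my = (y >= long_size).float() - (y < short_size).float()
    # +1 on the outer band of the grid, -1 in the central block.
    ml = (((x<short_size)+(x>=long_size)+(y<short_size)+(y>=long_size)) > 0).float()*2 - 1

    # Repeat each mask once per sample in the batch.
    mx_batch = torch.stack([mx.float()] * grad_output.size(0))
    my_batch = torch.stack([my.float()] * grad_output.size(0))
    ml_batch = torch.stack([ml.float()] * grad_output.size(0))

    # Move the masks and the result to the GPU when the incoming gradient lives there.
    if grad_output.is_cuda:
        mx_batch = mx_batch.cuda()
        my_batch = my_batch.cuda()
        ml_batch = ml_batch.cuda()
        ret = ret.cuda()

    # One scalar per sample for each of the three returned components.
    ret[:, 0] = (norm * mx_batch).sum(dim=1).sum(dim=1)
    ret[:, 1] = (norm * my_batch).sum(dim=1).sum(dim=1)
    ret[:, 2] = (norm * ml_batch).sum(dim=1).sum(dim=1)
    # None for the first forward input, ret for the second.
    return None, ret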
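
To make the question concrete, here is a minimal pure-Python sketch that rebuilds the three masks from the snippet above on a small grid so they can be inspected by eye. The helper name apn_backward_masks and the grid size 6 are hypothetical and chosen only for illustration; the thresholds and signs are copied from the snippet. Note that the function takes only in_size as input, so, as far as this snippet shows, the masks are defined on the output grid and do not depend on the predicted attention location. Whether that is an intentional approximation is something only the author can confirm.

```python
def apn_backward_masks(in_size):
    """Rebuild the mx, my, ml masks of the backward pass as nested lists."""
    long_size = in_size / 3 * 2
    short_size = in_size / 3

    def sign(v):
        # +1 in the last third, -1 in the first third, 0 in the middle third
        return int(v >= long_size) - int(v < short_size)

    rows = range(in_size)
    cols = range(in_size)
    # In the snippet x[i][j] = i and y[i][j] = j.
    mx = [[sign(i) for j in cols] for i in rows]
    my = [[sign(j) for j in cols] for i in rows]
    # +1 wherever either coordinate falls outside the middle third, else -1
    ml = [[1 if (sign(i) != 0 or sign(j) != 0) else -1 for j in cols] for i in rows]
    return mx, my, ml


if __name__ == "__main__":
    mx, my, ml = apn_backward_masks(6)
    for name, mask in (("mx", mx), ("my", my), ("ml", ml)):
        print(name)
        for row in mask:
            print(" ".join(f"{v:+d}" for v in row))
```

Running it prints the three grids: mx changes sign along the rows, my along the columns, and ml is negative only in the central block. In the backward pass, each returned component is then the sum of norm weighted by one of these masks.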

I have met the same problem. Have you figured this out yet?
