
Comments (8)

hanqiu-hq commented on June 29, 2024

Yes, the tensor h_w has a shape of [1, bs, 2].
Did you remove the brackets around gaussian in this call?

gaussian=[gaussian])[0]

from smca-detr.

lucasjinreal commented on June 29, 2024

@hanqiu-hq Hi, the error does not come from the transformer decoder but from the SMCA attention forward pass:

if naive:
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [
        bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(
            bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        attn_output_weights = attn_output_weights.view(
            bsz * num_heads, tgt_len, src_len)

    attn_output_weights = attn_output_weights + \
        gaussian[0].permute(2, 0, 1)

attn_output_weights = attn_output_weights + gaussian[0].permute(2, 0, 1)

This means the shape of gaussian[0] is not right, but how could this value be wrong? The image shape is fine, and so is h_w.
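For reference, here is a minimal standalone sketch of the naive path quoted above, showing the shape contract that the Gaussian prior has to satisfy. The function name naive_attention_with_gaussian is hypothetical, it takes the prior tensor directly rather than a one-element list, and it omits attn_mask; the shapes in the comments are assumptions read off the thread ([bsz * num_heads, tgt_len, src_len] for the weights, [tgt_len, H*W, bs * num_head] for the prior).

```python
import torch
import torch.nn.functional as F


def naive_attention_with_gaussian(q, k, gaussian, key_padding_mask=None):
    # Hypothetical helper (not the repo's function). Assumed shapes:
    #   q:                [bsz * num_heads, tgt_len, head_dim]
    #   k:                [bsz * num_heads, src_len, head_dim]
    #   gaussian:         [tgt_len, src_len, bsz * num_heads]
    #   key_padding_mask: [bsz, src_len], bool, True marks padded keys
    bh, tgt_len, _ = q.shape
    src_len = k.shape[1]
    weights = torch.bmm(q, k.transpose(1, 2))  # [bh, tgt_len, src_len]
    if key_padding_mask is not None:
        bsz = key_padding_mask.shape[0]
        weights = weights.view(bsz, bh // bsz, tgt_len, src_len)
        weights = weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
        weights = weights.view(bh, tgt_len, src_len)
    # The prior must be 3-D so that permute(2, 0, 1) yields [bh, tgt_len, src_len]
    weights = weights + gaussian.permute(2, 0, 1)
    return F.softmax(weights, dim=-1)


# Toy sizes: 2 images x 8 heads, 100 queries, 50 key positions
q = torch.randn(16, 100, 32)
k = torch.randn(16, 50, 32)
gaussian = torch.randn(100, 50, 16)  # [tgt_len, src_len, bsz * num_heads]
print(naive_attention_with_gaussian(q, k, gaussian).shape)  # torch.Size([16, 100, 50])
```

Any extra trailing dimension on the prior makes the permute(2, 0, 1) call invalid before the addition is even attempted.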


hanqiu-hq commented on June 29, 2024

The gaussian tensor has three dimensions, with shape [tgt_len, H*W, bs * num_head].
It is wrapped in a list when passed to the attention module:

gaussian=[gaussian])[0]

So gaussian[0] inside the attention module should have the same shape and number of dimensions as above.
You can try printing the shape of "gaussian[0]" in the attention module.

attn_output_weights = attn_output_weights + gaussian[0].permute(2, 0, 1)
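Following the suggestion to print the shape of gaussian[0], a small debugging guard can turn a cryptic failure into a readable message. This is a hypothetical helper (check_gaussian is not part of the repo); it assumes the caller passes the prior wrapped in a one-element list, as in gaussian=[gaussian], and that the expected layout is [tgt_len, H*W, bs * num_head].

```python
import torch


def check_gaussian(gaussian, bsz, num_heads, tgt_len, src_len):
    # Hypothetical debugging helper: validate gaussian[0] before permute(2, 0, 1).
    g = gaussian[0]  # the caller wraps the tensor in a list: gaussian=[gaussian]
    expected = (tgt_len, src_len, bsz * num_heads)
    if tuple(g.shape) != expected:
        raise ValueError(
            f"gaussian[0] has shape {tuple(g.shape)}, expected {expected}")
    return g.permute(2, 0, 1)  # -> [bsz * num_heads, tgt_len, src_len]


ok = check_gaussian([torch.zeros(4, 6, 16)], bsz=2, num_heads=8, tgt_len=4, src_len=6)
print(ok.shape)  # torch.Size([16, 4, 6])
```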


lucasjinreal commented on June 29, 2024

@hanqiu-hq My gaussian shape is:

gaussian shape:  torch.Size([100, 962, 80, 2])

This looks wrong: the trailing [80, 2] does not seem right, and the 2 seems redundant. What causes this?
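A short toy-size sketch of why a 4-D prior most likely breaks the attention code: permute(2, 0, 1) supplies three indices, so it only works on a 3-D tensor. Collapsing the trailing axis (here with a plain sum, as an assumption) restores the expected layout. The sizes are illustrative stand-ins for the [tgt_len, H*W, bs * num_head, 2] shape in the log.

```python
import torch

# Toy stand-in for a 4-D prior of shape [tgt_len, H*W, bs * num_head, 2]
g4 = torch.randn(5, 7, 16, 2)

try:
    g4.permute(2, 0, 1)  # three indices for a 4-D tensor
except RuntimeError as err:
    print("permute failed:", type(err).__name__)

g3 = g4.sum(-1)  # collapse the trailing axis -> [5, 7, 16]
print(g3.permute(2, 0, 1).shape)  # torch.Size([16, 5, 7])
```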


lucasjinreal commented on June 29, 2024

@hanqiu-hq Also, I think the h_w shape may not be what you said. Here are all the variable shapes I logged:

point shape:  torch.Size([100, 80, 2])
distance shape:  torch.Size([100, 858, 80, 2])
gaussian shape 0:  torch.Size([100, 858, 80, 2])
gaussian shape:  torch.Size([100, 858, 80, 2])
point shape:  torch.Size([100, 80, 2])
distance shape:  torch.Size([100, 864, 80, 2])
gaussian shape 0:  torch.Size([100, 864, 80, 2])
gaussian shape:  torch.Size([100, 864, 80, 2])

The shapes are already wrong starting from point.
From the code:

if self.layer_index == 0:
    point_sigmoid_ref_inter = self.point1(out)
    point_sigmoid_ref = point_sigmoid_ref_inter.sigmoid()
    point_sigmoid_ref = (h_w - 0) * point_sigmoid_ref / 32
    point_sigmoid_ref = point_sigmoid_ref.repeat(1, 1, 8)
else:
    point_sigmoid_ref = point_ref_previous
print('point_sigmoid_ref: ', point_sigmoid_ref.shape)
point = point_sigmoid_ref + point_sigmoid_offset
point = point.view(tgt_len, -1, 2)
distance = (point.unsqueeze(1) - grid.unsqueeze(0)).pow(2)
print('point shape: ', point.unsqueeze(1).shape)
print('grid shape: ', grid.unsqueeze(0).shape)
print('distance shape: ', distance.shape)

I am sure the problem is in your code; even the grid shape is right:

grid 0 shape:  torch.Size([1092, 80, 2])

So which shape is wrong here?
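To see which shapes are expected at this stage, here is a toy reconstruction of the point/grid/distance steps. It rests on several assumptions not stated in the thread: 8 heads (from repeat(1, 1, 8)), batch size 10 (so bs * num_heads = 80, as in the logs), a 26 x 33 stride-32 feature map (so H*W = 858), a zero point_sigmoid_offset, a (y, x) grid built with torch.meshgrid, and an nn.Linear standing in for self.point1. Under those assumptions the 4-D distance matches the logged shape, which suggests the trailing 2 is expected here and should only disappear in a later reduction.

```python
import torch

# Toy sizes (assumptions, not the repo's config)
tgt_len, bs, num_heads, d_model = 100, 10, 8, 256
img_h, img_w = 832, 1056                      # stride-32 feature map: 26 x 33
H, W = img_h // 32, img_w // 32

out = torch.randn(tgt_len, bs, d_model)       # decoder query features
point1 = torch.nn.Linear(d_model, 2)          # stand-in for self.point1
h_w = torch.tensor([img_h, img_w], dtype=torch.float).expand(1, bs, 2)  # [1, bs, 2]

# Reference point per query in feature-map units, copied once per head
ref = (h_w * point1(out).sigmoid() / 32).repeat(1, 1, num_heads)  # [tgt_len, bs, 2*num_heads]
offset = torch.zeros_like(ref)                # assumed zero offset for this sketch
point = (ref + offset).view(tgt_len, -1, 2)   # [tgt_len, bs*num_heads, 2]

# Feature-map grid of (y, x) coordinates, shared by every image and head
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
grid = torch.stack([ys, xs], dim=-1).view(-1, 1, 2).float()
grid = grid.repeat(1, bs * num_heads, 1)      # [H*W, bs*num_heads, 2]

# Broadcasting: [tgt_len, 1, B, 2] - [1, H*W, B, 2] -> [tgt_len, H*W, B, 2]
distance = (point.unsqueeze(1) - grid.unsqueeze(0)).pow(2)
print(point.shape, grid.shape, distance.shape)
# torch.Size([100, 80, 2]) torch.Size([858, 80, 2]) torch.Size([100, 858, 80, 2])
```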


hanqiu-hq commented on June 29, 2024

The last dimension of the tensor distance (size 2) is eliminated at:

distance = (distance * scale).sum(-1)

So the tensor gaussian should end up with three dimensions, [tgt_len, HW, bs * num_head], rather than four, [tgt_len, HW, bs * num_head, 2].
Did you skip the code at L280-L298?
if self.dynamic_scale == "type1":
    scale = 1
    distance = distance.sum(-1) * scale
elif self.dynamic_scale == "type2":
    scale = self.point3(out)
    scale = scale * scale
    scale = scale.reshape(tgt_len, -1).unsqueeze(1)
    distance = distance.sum(-1) * scale
elif self.dynamic_scale == "type3":
    scale = self.point3(out)
    scale = scale * scale
    scale = scale.reshape(tgt_len, -1, 2).unsqueeze(1)
    distance = (distance * scale).sum(-1)
elif self.dynamic_scale == "type4":
    scale = self.point3(out)
    scale = scale * scale
    scale = scale.reshape(tgt_len, -1, 3).unsqueeze(1)
    distance = torch.cat([distance, torch.prod(distance, dim=-1, keepdim=True)], dim=-1)
    distance = (distance * scale).sum(-1)
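The branches above can be restated as a standalone function to make the reduction explicit. The name reduce_distance is hypothetical, and scale_feat is a plain tensor standing in for self.point3(out) with an assumed shape of [tgt_len, bs, num_heads * k]. Every branch collapses the trailing axis of distance; the final raise is an addition that is not in the quoted chain (which has no else), so an unexpected value such as a bool fails loudly instead of silently leaving distance 4-D.

```python
import torch


def reduce_distance(distance, dynamic_scale, scale_feat=None):
    # Hypothetical restatement of the dynamic_scale branches.
    #   distance:   [tgt_len, HW, bs*num_heads, 2] squared (dy, dx) offsets
    #   scale_feat: stand-in for self.point3(out), [tgt_len, bs, num_heads*k]
    # Returns [tgt_len, HW, bs*num_heads].
    tgt_len = distance.shape[0]
    if dynamic_scale == "type1":
        return distance.sum(-1)
    if dynamic_scale not in ("type2", "type3", "type4"):
        # Without this check a bool such as False would skip every branch
        # and leave the trailing axis of size 2 in place.
        raise ValueError(f"unknown dynamic_scale {dynamic_scale!r}")
    scale = scale_feat * scale_feat  # squared, so scales are non-negative
    if dynamic_scale == "type2":  # one scale per head
        return distance.sum(-1) * scale.reshape(tgt_len, -1).unsqueeze(1)
    if dynamic_scale == "type3":  # separate y and x scales per head
        return (distance * scale.reshape(tgt_len, -1, 2).unsqueeze(1)).sum(-1)
    # type4: y, x and a cross term built from their product
    distance = torch.cat([distance, torch.prod(distance, dim=-1, keepdim=True)], dim=-1)
    return (distance * scale.reshape(tgt_len, -1, 3).unsqueeze(1)).sum(-1)


dist = torch.rand(4, 6, 16, 2)  # toy [tgt_len, HW, bs*num_heads, 2]
print(reduce_distance(dist, "type3", torch.randn(4, 2, 16)).shape)  # torch.Size([4, 6, 16])
```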


lucasjinreal commented on June 29, 2024

@hanqiu-hq Holy shit... it seems the code compares dynamic_scale against strings, but its default value is a bool....

What should I set here? type1 or type4?
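One way to keep a bool default from reaching that string comparison is to constrain the command-line option itself. The flag name --dynamic_scale and this parser are hypothetical (check the repo's actual argument definitions), and which type to use is for the authors to say; type3 below only exercises the parser and is not a recommendation. Making the option required with explicit choices means argparse rejects both a missing value and anything outside the four strings.

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag: only the strings the model code compares against are
# accepted, and there is no default, so a bool can never slip through.
parser.add_argument(
    "--dynamic_scale",
    choices=["type1", "type2", "type3", "type4"],
    required=True,
    help="which dynamic_scale branch to use",
)

args = parser.parse_args(["--dynamic_scale", "type3"])
print(args.dynamic_scale)  # type3
```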


lucasjinreal commented on June 29, 2024

@hanqiu-hq I started training, but the speed does not seem normal:
