
Comments (7)

zachteed commented on September 27, 2024

At the lowest pyramid level, the vector returned by the lookup operator at a single point is a function of 72x72 pairs (576x576 in pixels). Yes, the search range is determined by the lookup radius, but it covers the majority, if not all, of the image. So at Sintel resolution, ~50% of the all-pairs correlation volume is used at each step.

The only reason the efficient implementation exists is that we use the dot product for correlation and average pooling for aggregation. Both operations are linear. Swap either one for a non-linear function (e.g. L1 distance, max pooling) and our "efficient implementation" no longer exists.
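
The linearity argument can be checked numerically: average-pooling the correlation volume over the frame-2 dimensions gives the same result as correlating against an average-pooled fmap2. A minimal sketch (shapes and variable names are illustrative, not from the RAFT codebase):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, H, W = 1, 8, 4, 4
fmap1 = torch.randn(B, D, H * W)   # frame-1 features, flattened to H*W vectors
fmap2 = torch.randn(B, D, H, W)    # frame-2 features

# all-pairs correlation volume: one [H, W] response map per frame-1 pixel
corr = torch.einsum('bdn,bdhw->bnhw', fmap1, fmap2)

# (a) build the coarser pyramid level by pooling the correlation volume
pooled_corr = F.avg_pool2d(corr, 2, stride=2)

# (b) pool the frame-2 features first, then correlate
corr_of_pooled = torch.einsum('bdn,bdhw->bnhw', fmap1,
                              F.avg_pool2d(fmap2, 2, stride=2))

print(torch.allclose(pooled_corr, corr_of_pooled, atol=1e-5))  # True
```

Replace the dot product with an L1 distance, or the average pool with a max pool, and the equality breaks, which is exactly why the efficient implementation requires both operations to be linear.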

The code I provided above is ~8x slower than our all-pairs implementation. So even when an implementation with lower complexity exists, there are practical advantages to our formulation.

from raft.

ProNoobLi commented on September 27, 2024

Hi,
To verify my understanding, I implemented a very naive "efficient" approach, but the results became worse.
[image] Warped results from the flow predicted by the original code (all-pairs), masked by the occlusion mask.
[image] Warped results from the flow predicted by my code (local neighbor search), masked by the occlusion mask.

I can see that my implementation is much worse than yours, even though it follows the reasoning in my post above. The code is attached here; please feel free to point out what I misunderstood.
[images of code]


ProNoobLi commented on September 27, 2024

Let me explain here.
Your all-pairs correlation method and implementation are straightforward and easy to understand:

  1. Assume fmap1 and fmap2 have shape [B, C, H, W] = [B, 128, 48, 128].
  2. Compute the all-pairs correlation with a 6144x6144 matrix multiplication (6144 = 48*128).
  3. Downsample this matrix to get a 4-level corr_pyramid.
  4. Look up the correlation values of a 7*7 neighborhood in fmap2 around each pixel of fmap1.
  5. Concatenate the four correlation feature maps [48, 128, 49]*4 and return a [48, 128, 196] feature.
    From my understanding, although the all-pairs correlation is computed, only 196 of the 6144 dot products per pixel are actually used. So ideally there exists an alternative implementation that takes only O(NM), as your paper mentions.
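
Steps 1-3 above can be sketched directly. This is an illustrative sketch, not the RAFT code; smaller spatial dimensions stand in for 48x128 so the example stays light (at full RAFT resolution the matrix would be 6144x6144):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 128, 12, 32        # stand-in for [B, 128, 48, 128]
fmap1 = torch.randn(B, C, H, W)
fmap2 = torch.randn(B, C, H, W)

# step 2: all-pairs correlation via a single (H*W) x (H*W) matmul,
# scaled by 1/sqrt(C) as in RAFT
f1 = fmap1.view(B, C, H * W)
f2 = fmap2.view(B, C, H * W)
corr = torch.matmul(f1.transpose(1, 2), f2) / C ** 0.5   # [B, H*W, H*W]

# step 3: pool over the frame-2 dimensions to build the 4-level pyramid
corr = corr.view(B * H * W, 1, H, W)
pyramid = [corr]
for _ in range(3):
    pyramid.append(F.avg_pool2d(pyramid[-1], 2, stride=2))

print([tuple(p.shape) for p in pyramid])
```

Each frame-1 pixel keeps a full response map over frame 2 at every level; the lookup in steps 4-5 then reads only a small window out of each map.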

As for my alternative implementation: I tried to reproduce your efficient approach, but I am not sure my understanding is correct.

  1. Downsample fmap2 by ratios [1, 2, 4, 8].
  2. At each level, loop over each feature of fmap1.
  3. Grid-sample the corresponding [7, 7, 128] neighborhood in fmap2 at the flow coordinates.
  4. Compute the dot product between the feature and its flow-warped neighborhood to get a [1, 49] feature.
  5. Repeat steps 2-4, sliding over fmap1 to get feature maps [48, 128, 49] * 4.

I know the alternative implementation above is not "efficient", since it relies heavily on a double for loop. I just want to verify the relationship between the local lookups and the all-pairs correlation.


zachteed commented on September 27, 2024

Hi, I'm having trouble following. But here is some code that will do the job. Just plug it in instead of the default CorrBlock. You can check that it gives exactly (to numerical precision) the same results as the default implementation. There is no need to loop over the neighborhood; you can just use the built-in grid sampler. Again, this is very slow; you will need to write a CUDA implementation for it to be practical.

import torch
import torch.nn.functional as F


class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.fmap1 = fmap1
        self.fmap2 = fmap2

    def corr(self, fmap1, fmap2, coords):
        # Lift fmap2 to 5D so a single grid_sample call gathers, for every
        # pixel of fmap1, its sampled neighborhood in fmap2.
        B, D, H, W = fmap2.shape
        fmap1 = fmap1.unsqueeze(dim=-1)   # [B, D, H, W, 1]
        fmap2 = fmap2.unsqueeze(dim=-1)   # [B, D, H, W, 1]

        # map grid coordinates to [-1, 1]
        xgrid, ygrid = coords.split([1, 1], dim=-1)
        xgrid = 2*xgrid/(W-1) - 1
        ygrid = 2*ygrid/(H-1) - 1
        zgrid = torch.zeros_like(xgrid) - 1   # selects the singleton last dim
        grid = torch.cat([zgrid, xgrid, ygrid], dim=-1)

        fmapw = F.grid_sample(fmap2, grid, align_corners=True)

        # dot product over the feature dimension, scaled by 1/sqrt(D)
        corr = torch.sum(fmap1*fmapw, dim=1)
        return corr / torch.sqrt(torch.tensor(D).float())

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        fmap1 = self.fmap1
        fmap2 = self.fmap2

        out_pyramid = []
        for i in range(self.num_levels):
            # (2r+1) x (2r+1) lookup window around each centroid
            dx = torch.linspace(-r, r, 2*r+1)
            dy = torch.linspace(-r, r, 2*r+1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), dim=-1).to(coords.device)

            centroid_lvl = coords.reshape(batch, h1, w1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta.view(-1, 2)

            # correlate against the current level, then pool fmap2 to
            # form the next (coarser) level on the fly
            corr = self.corr(fmap1, fmap2, coords_lvl)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()


zachteed commented on September 27, 2024

I think the problem with your code was that you weren't scaling the correlation features by 1.0/sqrt(D).
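
The scaling matters because the raw dot product of D-dimensional features has a standard deviation that grows like sqrt(D) (for roughly unit-variance features), so unscaled correlation magnitudes depend on the feature dimension. A quick sanity check with random features:

```python
import torch

torch.manual_seed(0)
for D in (64, 256, 1024):
    a = torch.randn(10000, D)
    b = torch.randn(10000, D)
    dots = (a * b).sum(dim=1)
    # unscaled std grows like sqrt(D); scaled std stays near 1
    print(D, round(dots.std().item(), 1),
          round((dots / D ** 0.5).std().item(), 2))
```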


ProNoobLi commented on September 27, 2024

[quotes the code from the previous comment]

Thank you for the code. It's straightforward to understand.
Here is my question: either way, you grid-sample the neighborhood of each feature, so the all-pairs correlation isn't used in the following step, right? I think it is still a local correlation search, with the search range determined by the lookup radius.


ProNoobLi commented on September 27, 2024

Great, I see. It sounds like the receptive area at the lowest pyramid level covers a much wider range than at the highest level, but here the information comes from the correlation volume rather than from a CNN.
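
Concretely, with lookup radius r = 4 (the default in the code above) and RAFT's 8x feature downsampling, each pyramid level doubles the area that the 9x9 lookup covers; at the coarsest level it reaches 72x72 feature pixels, i.e. the 576x576-pixel figure mentioned earlier in the thread:

```python
# effective search window per pyramid level (radius r = 4, 8x-downsampled features)
r = 4
for level in range(4):
    cells = 2 * r + 1              # 9 lookup cells per axis
    feat_px = cells * 2 ** level   # window size in feature-map pixels
    img_px = feat_px * 8           # window size in full-resolution image pixels
    print(level, feat_px, img_px)  # level 3 -> 72, 576
```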

PWC-Net and FlowNet also downsample the feature maps, but they use a local search at each level, whereas yours effectively uses all-pairs correlation at each level.

It makes sense! A new idea. Congrats

