Comments (7)
At the lowest pyramid level, the vector returned by the lookup operator at a single point is a function of 72x72 pairs (576x576 in pixels). Yes, the search range is determined by the lookup radius, but the search range covers the majority, if not all, of the image. So at Sintel resolution, roughly 50% of the all-pairs correlation volume is being used at a particular step.
The only reason the efficient implementation exists is that we use the dot product for correlation and average pooling for aggregation. Both of these operations are linear. Swap either of them for a non-linear function (e.g. L1 distance or max pooling), and our "efficient implementation" no longer exists.
The code I provided above is ~8x slower than our all-pairs implementation. So even when an implementation with lower complexity exists, there are practical advantages to our formulation.
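To make the linearity point concrete, here is a small toy check (my own sketch, not code from the repo): pooling the all-pairs correlation volume over the fmap2 axes gives exactly the same result as correlating against an average-pooled fmap2, because both the dot product and average pooling are linear.

```python
import torch
import torch.nn.functional as F

# Toy check of the linearity argument above (my own sketch, not repo code):
# pooling the all-pairs volume over the fmap2 axes equals correlating
# against a pooled fmap2, because dot product and average pooling are linear.
torch.manual_seed(0)
B, D, H, W = 1, 8, 4, 4
f1 = torch.randn(B, D, H, W)
f2 = torch.randn(B, D, H, W)

# full all-pairs volume [B, H, W, H, W], then pool the last two (fmap2) dims
corr = torch.einsum('bdij,bdkl->bijkl', f1, f2)
pooled_volume = F.avg_pool2d(corr.reshape(B * H * W, 1, H, W), 2)
pooled_volume = pooled_volume.reshape(B, H, W, H // 2, W // 2)

# the same result without ever materializing the full volume at that level
corr_from_pooled = torch.einsum('bdij,bdkl->bijkl', f1, F.avg_pool2d(f2, 2))

match = torch.allclose(pooled_volume, corr_from_pooled, atol=1e-5)
print(match)
```

Replace the dot product or the average pooling with a non-linear operation and this identity breaks, which is exactly why the efficient formulation stops existing.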
from raft.
Hi,
To verify my understanding, I implemented a very naive "efficient" approach, but the result became worse.
Warped results from the flow predicted by the original code (all pairs). The warped results are masked by the occlusion mask.
Warped results from the flow predicted by my code (local neighbor search). The warped results are masked by the occlusion mask.
I can see that my implementation is much worse than yours, even though it follows my understanding from the post above. The code is attached here; please feel free to point out what I misunderstood.
Let me explain here.
For your all-pairs correlation, the method and implementation are straightforward and easy to understand:
- Assume fmap1 and fmap2 have shape [B, C, H, W] = [B, 128, 48, 128].
- Compute the all-pairs correlation as a 6144x6144 matrix multiplication (6144 = 48*128).
- Downsample that matrix to get the corr_pyramid for 4 levels.
- Look up the correlation values of the 7x7 neighborhood in fmap2 surrounding each pixel of fmap1.
- Concatenate the four correlation feature maps [48, 128, 49]*4 and return a [48, 128, 196] feature.
From my understanding, although the all-pairs correlation is computed, only 196 out of the 6144 dot products per pixel are actually used. So ideally there exists an alternative implementation that takes only O(NM), as your paper mentions.
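The counting behind this claim can be sanity-checked with the shapes from this example (assuming a lookup radius of 3, which gives the 7x7 window described above):

```python
# Sanity-check the counting above with the shapes from this example.
H, W = 48, 128
N = H * W                      # positions in fmap1 (and in fmap2 at level 0)
r = 3                          # radius giving the 7x7 window described above
per_level = (2 * r + 1) ** 2   # 49 correlations read per pixel per level
levels = 4
used = per_level * levels      # 196 correlations actually read per pixel
print(N, used)                 # 6144 196
```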
For my alternative implementation, I tried to follow your efficient approach, but I am not sure it is correct:
- Downsample fmap2 by ratios [1, 2, 4, 8].
- At each level, loop over each feature of fmap1.
- Grid-sample the corresponding [7, 7, 128] neighborhood in fmap2 at the flow_coords.
- Compute the dot product between the feature and its (flow-shifted) neighborhood to get a [1, 49] feature.
- Repeat steps 2-4, sliding over fmap1 to finally obtain feature maps [48, 128, 49] * 4.
I know this alternative implementation is not "efficient", since it relies on a double for loop. But I just want to verify the relationship between the local lookups and the all-pairs correlation.
Hi, I'm having trouble following, but here is some code that will do the job. Just plug it in instead of the default CorrBlock. You can check that it gives exactly (to numerical precision) the same results as the default implementation. There is no need to loop over the neighborhood; you can just use the built-in grid sampler. Again, this is very slow, so you will need to write a CUDA implementation for it to be practical.
```python
import torch
import torch.nn.functional as F


class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.fmap1 = fmap1
        self.fmap2 = fmap2

    def corr(self, fmap1, fmap2, coords):
        B, D, H, W = fmap2.shape
        # view the 2D maps as 3D volumes of depth 1 so a single 3D
        # grid_sample call can gather the whole neighborhood per pixel
        fmap1 = fmap1.unsqueeze(dim=-1)
        fmap2 = fmap2.unsqueeze(dim=-1)

        # map grid coordinates to [-1, 1]
        xgrid, ygrid = coords.split([1, 1], dim=-1)
        xgrid = 2 * xgrid / (W - 1) - 1
        ygrid = 2 * ygrid / (H - 1) - 1

        zgrid = torch.zeros_like(xgrid) - 1  # depth dim has size 1
        grid = torch.cat([zgrid, xgrid, ygrid], dim=-1)
        fmapw = F.grid_sample(fmap2, grid, align_corners=True)

        corr = torch.sum(fmap1 * fmapw, dim=1)
        return corr / torch.sqrt(torch.tensor(D).float())

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        fmap1 = self.fmap1
        fmap2 = self.fmap2

        out_pyramid = []
        for i in range(self.num_levels):
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(batch, h1, w1, 1, 2) / 2 ** i
            coords_lvl = centroid_lvl + delta.view(-1, 2)

            corr = self.corr(fmap1, fmap2, coords_lvl)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()
```
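The part of this code that is easiest to miss is the 3D grid_sample trick: the 2D feature map is viewed as a 3D volume of depth 1, with a constant -1 depth coordinate, so one call can gather an arbitrary set of neighbors per output pixel. A standalone toy check (my own example, not from the repo) that this sampling matches direct indexing at integer coordinates:

```python
import torch
import torch.nn.functional as F

# Toy illustration (my own example, not repo code) of the 3D grid_sample
# trick in CorrBlock.corr: view a 2D map [B, C, H, W] as a 5D volume of
# depth 1, so a grid of shape [B, h, w, K, 3] gathers K samples per pixel.
torch.manual_seed(0)
B, C, H, W, K = 1, 4, 5, 6, 3
fmap = torch.randn(B, C, H, W)

# K integer sample points (for one output location, to keep the demo tiny)
xs = torch.tensor([0.0, 2.0, 5.0])
ys = torch.tensor([1.0, 3.0, 4.0])

xgrid = 2 * xs / (W - 1) - 1
ygrid = 2 * ys / (H - 1) - 1
zgrid = torch.zeros_like(xgrid) - 1          # the size-1 depth dimension
grid = torch.stack([zgrid, xgrid, ygrid], dim=-1).view(B, 1, 1, K, 3)

# output shape: [B, C, 1, 1, K]
sampled = F.grid_sample(fmap.unsqueeze(-1), grid, align_corners=True)

# at integer coordinates, bilinear sampling equals direct indexing
direct = torch.stack([fmap[0, :, int(y), int(x)] for x, y in zip(xs, ys)], dim=-1)
ok = torch.allclose(sampled[0, :, 0, 0], direct, atol=1e-5)
print(ok)
```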
I think the problem with your code was that you weren't scaling the correlation features by `1.0/sqrt(D)`.
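For intuition on why that scaling matters, here is a quick toy check of my own (not from the repo): the dot product of two random unit-variance D-dimensional features has standard deviation about sqrt(D), so without the division the correlation features grow with the feature dimension.

```python
import torch

# Toy check (my own illustration): dot products of D-dim standard-normal
# features have std ~sqrt(D); dividing by sqrt(D) keeps them at unit scale.
torch.manual_seed(0)
D, n = 256, 100000
a = torch.randn(n, D)
b = torch.randn(n, D)

dots = (a * b).sum(dim=1)
raw_std = dots.std().item()               # ~ sqrt(256) = 16
scaled_std = (dots / D ** 0.5).std().item()  # ~ 1
print(raw_std, scaled_std)
```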
Thank you for the code. It's straightforward to understand.
Here is my question: you still grid-sample the neighborhood of each feature, so the all-pairs correlation isn't actually used in the lookup step, right? I think it is still a local correlation search, with the search range determined by the lookup radius.
Great, I see. It sounds like the receptive area at the lowest (coarsest) pyramid level covers a wider range than at the higher-resolution levels, but here all of that information comes not from the CNN but from the correlation itself.
PWC-Net and FlowNet downsample the feature map but use a local search at each level, while yours uses all-pairs at every level.
It makes sense! A new idea. Congrats!