
forward-warp's Introduction

Forward Warp PyTorch Version

Tested with pytorch=0.4.0, python=3.6, and CUDA=9.0.

Install

export CUDA_HOME=/usr/local/cuda  # use your own CUDA path instead
chmod a+x install.sh
./install.sh

Test

cd test
python test.py

Usage

from Forward_Warp import forward_warp

# the default interpolation mode is Bilinear
fw = forward_warp()
im2_bilinear = fw(im0, flow)
# use interpolation mode Nearest
# Note: with Nearest mode, the gradient w.r.t. the input flow is zero in the backward pass.
fw = forward_warp(interpolation_mode="Nearest")
im2_nearest = fw(im0, flow)
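
The README does not state the expected tensor shapes; the following sketch infers them from the issue examples further down (the shapes and the .cuda() calls are assumptions, not documented behavior):

import torch
from Forward_Warp import forward_warp

# Assumed shapes, inferred from the issues below (not stated in the README):
#   im0:  [B, C, H, W] source image
#   flow: [B, H, W, 2] per-pixel displacement from im0 towards im2
im0 = torch.randn(1, 3, 32, 32).cuda()
flow = torch.zeros(1, 32, 32, 2).cuda()

fw = forward_warp()
im2 = fw(im0, flow)  # output has the same shape as im0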

forward-warp's People

Contributors

lizhihao6

forward-warp's Issues

The error in forward-propagation

Thanks for providing the code of the forward warping operation in PyTorch. However, I get an error when I add forward_warp to an nn.Module; the error occurs during forward propagation. Here are the code and the error message.

import torch
import os
import torch.nn as nn

from Forward_Warp import forward_warp


class FW(nn.Module):
    def __init__(self):
        super(FW, self).__init__()
        self.conv = nn.Conv2d(3, 2, 3, 1, 1)
        self.fw = forward_warp(interpolation_mode='Bilinear')

    def forward(self, x):
        flow = self.conv(x).permute(0, 2, 3, 1).contiguous()
        return self.fw(x, flow)


if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    criterion_train = torch.nn.L1Loss()
    model = FW()

    x = torch.randn(4, 3, 32, 32).cuda().contiguous()
    y = torch.randn(4, 3, 32, 32).cuda()
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for i in range(10):
        y_pred = model(x)
        loss = criterion_train(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(loss.item())
Traceback (most recent call last):
  File "Forward-Warp-master/test/test_new.py", line 31, in <module>
    y_pred = model(x)
  File "/data/anaconda3/envs/pytorch1.4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "Forward-Warp-master/test/test_new.py", line 16, in forward
    return self.fw(x, flow)
  File "/data/anaconda3/envs/pytorch1.4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/anaconda3/envs/pytorch1.4/lib/python3.6/site-packages/Forward_Warp-0.0.1-py3.6.egg/Forward_Warp/forward_warp.py", line 59, in forward
TypeError: save_for_backward can only save variables, but argument 2 is of type int

Moreover, the code works if wrapped in torch.no_grad().
The flow is calculated by a CNN; could that be causing the problem?

The error occurs in both pytorch 1.4 and pytorch 1.7, with CUDA 10.1 and Python 3.6.
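
A minimal sketch of a common workaround for this save_for_backward error (general custom-Function practice, not this repository's actual implementation): save_for_backward accepts only tensors, so a non-tensor argument such as an integer interpolation flag can be stored on ctx directly.

import torch
from torch.autograd import Function

class ForwardWarpSketch(Function):
    # The real CUDA warp is replaced by an identity op here; the point is only
    # how tensor and non-tensor arguments are stored for the backward pass.
    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode):
        ctx.save_for_backward(im0, flow)              # tensors only
        ctx.interpolation_mode = interpolation_mode   # plain Python values go on ctx
        return im0.clone()

    @staticmethod
    def backward(ctx, grad_output):
        im0, flow = ctx.saved_tensors
        mode = ctx.interpolation_mode                 # retrieved without save_for_backward
        return grad_output, torch.zeros_like(flow), None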

illegal memory access

I got this error when using this library to optimize flow with an LPIPS loss:

RuntimeError: CUDA error: an illegal memory access was encountered

This problem may be related to my loss, my torch version, or my CUDA version, but I finally solved it by changing

auto flow_grad = at::empty_like(flow);

to

  auto flow_grad = at::zeros_like(flow);

In the backward kernel, flow_grad is only updated at positions that are not "outside", while the "outside" positions keep whatever values empty_like initialized them with, which I believe are arbitrary (uninitialized) values.
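
A small CPU-side Python analogue of that reasoning (not the CUDA kernel itself): when a buffer comes from empty_like and only some positions are written, the remaining entries keep arbitrary memory contents.

import torch

flow = torch.randn(4, 4, 2)
inside = torch.zeros(4, 4, dtype=torch.bool)
inside[1:3, 1:3] = True  # pretend only these positions are "not outside"

grad_empty = torch.empty_like(flow)  # uninitialized memory
grad_zeros = torch.zeros_like(flow)  # safe default

grad_empty[inside] = 1.0  # only the "inside" entries are ever written
grad_zeros[inside] = 1.0

print(grad_empty)  # entries outside the mask are whatever was in memory
print(grad_zeros)  # entries outside the mask are exactly zero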

possible holes in forward warp?

Hi, thanks for sharing this work. I am curious how holes are handled in forward warp. For example, these 4 pixels
(0,0) (1,0)
(0,1) (1,1)
are warped to
(0,0) (5,0)
(0,5) (5,5)
According to the Python code, I cannot figure out how the pixels in the middle are filled.
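
A hypothetical way to see the hole problem in code (it assumes a CUDA build and that flow[..., 0] / flow[..., 1] are the horizontal / vertical offsets, which the repository does not document explicitly): a flow that stretches the image by a factor of two sends source pixel x to 2x, so the in-between target positions receive no source pixel at all.

import torch
from Forward_Warp import forward_warp

H = W = 8
im0 = torch.ones(1, 1, H, W)
flow = torch.zeros(1, H, W, 2)
flow[0, :, :, 0] = torch.arange(W).float().view(1, W)  # dx = x, so x' = 2x
flow[0, :, :, 1] = torch.arange(H).float().view(H, 1)  # dy = y, so y' = 2y

fw = forward_warp()
im1 = fw(im0.cuda(), flow.cuda())
print(im1[0, 0])  # positions that no source pixel maps to remain zero (holes)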

N-to-1 mapping

Hi @lizhihao6 , first thank you for this great repo.
I just began to use it and found that the forward-warped image seems to contain some large values. After checking the Python implementation, it seems that it simply accumulates all the pixel values mapped to each target pixel. It is common for multiple source pixels to be forward-mapped to the same target pixel; in that case, a weighted average (instead of adding them all up) seems more reasonable. But to do that, we have to keep counting the (fractional) weights of the source pixels mapped to each target pixel.
Thanks.
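
A rough sketch of the weighted-average idea suggested above (not part of this repository's API; it assumes the flow layout used in the other issues): splat the image and an all-ones weight map with the same flow, then divide, so overlapping contributions are averaged rather than summed.

import torch
from Forward_Warp import forward_warp

def forward_warp_avg(im0, flow, eps=1e-6):
    # im0: [B, C, H, W], flow: [B, H, W, 2]
    fw = forward_warp()
    warped_sum = fw(im0, flow)                       # accumulated pixel values
    weights = fw(torch.ones_like(im0[:, :1]), flow)  # accumulated fractional counts
    return warped_sum / (weights + eps)              # average the N-to-1 contributions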

docker install error

I used the following Dockerfile to install with Docker:

FROM <base_image>

# set the working directory
WORKDIR /app

# install git
RUN apt-get update && apt-get install -y git

# clone the project
RUN git clone https://github.com/lizhihao6/Forward-Warp.git

# set CUDA_HOME
ENV CUDA_HOME /usr/local/cuda

# install the project
WORKDIR /app/Forward-Warp
RUN chmod a+x install.sh
RUN ./install.sh

but got this error:

#11 [8/8] RUN ./install.sh
#11 0.417 usage: conda [-h] [--no-plugins] [-V] COMMAND ...
#11 0.417 conda: error: argument COMMAND: invalid choice: 'activate' (choose from 'clean', 'compare', 'config', 'create', 'info', 'init', 'install', 'list', 'notices', 'package', 'remove', 'uninstall', 'rename', 'run', 'search', 'update', 'upgrade', 'doctor', 'env', 'content-trust')
#11 2.247 No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
#11 2.289 /opt/miniconda/lib/python3.10/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
#11 2.289   warnings.warn(
#11 2.392 /opt/miniconda/lib/python3.10/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
#11 2.392   warnings.warn(
#11 2.617 /opt/miniconda/lib/python3.10/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize NVML
#11 2.617   warnings.warn("Can't initialize NVML")
#11 2.617 Traceback (most recent call last):
#11 2.617   File "/app/Forward-Warp/Forward_Warp/cuda/setup.py", line 4, in <module>
#11 2.617     setup(
#11 2.617   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/__init__.py", line 87, in setup
#11 2.617     return distutils.core.setup(**attrs)
#11 2.617   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 185, in setup
#11 2.617     return run_commands(dist)
#11 2.617   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
#11 2.618     dist.run_commands()
#11 2.618   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
#11 2.618     self.run_command(cmd)
#11 2.618   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/dist.py", line 1208, in run_command
#11 2.618     super().run_command(command)
#11 2.618   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
#11 2.618     cmd_obj.run()
#11 2.618   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/command/install.py", line 74, in run
#11 2.618     self.do_egg_install()
#11 2.618   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/command/install.py", line 123, in do_egg_install
#11 2.618     self.run_command('bdist_egg')
#11 2.618   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
#11 2.619     self.distribution.run_command(command)
#11 2.619   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/dist.py", line 1208, in run_command
#11 2.619     super().run_command(command)
#11 2.619   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
#11 2.619     cmd_obj.run()
#11 2.619   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/command/bdist_egg.py", line 165, in run
#11 2.619     cmd = self.call_command('install_lib', warn_dir=0)
#11 2.619   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/command/bdist_egg.py", line 151, in call_command
#11 2.619     self.run_command(cmdname)
#11 2.619   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
#11 2.619     self.distribution.run_command(command)
#11 2.619   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/dist.py", line 1208, in run_command
#11 2.620     super().run_command(command)
#11 2.620   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
#11 2.620     cmd_obj.run()
#11 2.620   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/command/install_lib.py", line 11, in run
#11 2.620     self.build()
#11 2.620   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/command/install_lib.py", line 112, in build
#11 2.620     self.run_command('build_ext')
#11 2.620   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
#11 2.620     self.distribution.run_command(command)
#11 2.620   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/dist.py", line 1208, in run_command
#11 2.620     super().run_command(command)
#11 2.620   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
#11 2.621     cmd_obj.run()
#11 2.621   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 84, in run
#11 2.621     _build_ext.run(self)
#11 2.621   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 346, in run
#11 2.621     self.build_extensions()
#11 2.621   File "/opt/miniconda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 843, in build_extensions
#11 2.621     build_ext.build_extensions(self)
#11 2.621   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 468, in build_extensions
#11 2.621     self._build_extensions_serial()
#11 2.621   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 494, in _build_extensions_serial
#11 2.621     self.build_extension(ext)
#11 2.621   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 246, in build_extension
#11 2.621     _build_ext.build_extension(self, ext)
#11 2.621   File "/opt/miniconda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 549, in build_extension
#11 2.622     objects = self.compiler.compile(
#11 2.622   File "/opt/miniconda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 649, in unix_wrap_ninja_compile
#11 2.622     cuda_post_cflags = unix_cuda_flags(cuda_post_cflags)
#11 2.622   File "/opt/miniconda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 548, in unix_cuda_flags
#11 2.622     cflags + _get_cuda_arch_flags(cflags))
#11 2.622   File "/opt/miniconda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1773, in _get_cuda_arch_flags
#11 2.622     arch_list[-1] += '+PTX'
#11 2.622 IndexError: list index out of range
#11 3.215 running install
#11 3.215 /opt/miniconda/lib/python3.10/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
#11 3.215   warnings.warn(
#11 3.266 /opt/miniconda/lib/python3.10/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
#11 3.266   warnings.warn(
#11 3.408 running bdist_egg
#11 3.432 running egg_info
#11 3.432 creating Forward_Warp.egg-info
#11 3.439 writing Forward_Warp.egg-info/PKG-INFO
#11 3.439 writing dependency_links to Forward_Warp.egg-info/dependency_links.txt
#11 3.439 writing top-level names to Forward_Warp.egg-info/top_level.txt
#11 3.440 writing manifest file 'Forward_Warp.egg-info/SOURCES.txt'
#11 3.455 reading manifest file 'Forward_Warp.egg-info/SOURCES.txt'
#11 3.455 adding license file 'LICENSE'
#11 3.456 writing manifest file 'Forward_Warp.egg-info/SOURCES.txt'
#11 3.456 installing library code to build/bdist.linux-x86_64/egg
#11 3.456 running install_lib
#11 3.456 running build_py
#11 3.456 creating build
#11 3.456 creating build/lib
#11 3.457 creating build/lib/Forward_Warp
#11 3.457 copying Forward_Warp/forward_warp.py -> build/lib/Forward_Warp
#11 3.457 copying Forward_Warp/__init__.py -> build/lib/Forward_Warp
#11 3.457 creating build/lib/Forward_Warp/python
#11 3.457 copying Forward_Warp/python/__init__.py -> build/lib/Forward_Warp/python
#11 3.457 copying Forward_Warp/python/forward_warp_python.py -> build/lib/Forward_Warp/python
#11 3.458 creating build/bdist.linux-x86_64
#11 3.458 creating build/bdist.linux-x86_64/egg
#11 3.458 creating build/bdist.linux-x86_64/egg/Forward_Warp
#11 3.458 creating build/bdist.linux-x86_64/egg/Forward_Warp/python
#11 3.458 copying build/lib/Forward_Warp/python/__init__.py -> build/bdist.linux-x86_64/egg/Forward_Warp/python
#11 3.458 copying build/lib/Forward_Warp/python/forward_warp_python.py -> build/bdist.linux-x86_64/egg/Forward_Warp/python
#11 3.459 copying build/lib/Forward_Warp/forward_warp.py -> build/bdist.linux-x86_64/egg/Forward_Warp
#11 3.459 copying build/lib/Forward_Warp/__init__.py -> build/bdist.linux-x86_64/egg/Forward_Warp
#11 3.459 byte-compiling build/bdist.linux-x86_64/egg/Forward_Warp/python/__init__.py to __init__.cpython-310.pyc
#11 3.460 byte-compiling build/bdist.linux-x86_64/egg/Forward_Warp/python/forward_warp_python.py to forward_warp_python.cpython-310.pyc
#11 3.461 byte-compiling build/bdist.linux-x86_64/egg/Forward_Warp/forward_warp.py to forward_warp.cpython-310.pyc
#11 3.462 build/bdist.linux-x86_64/egg/Forward_Warp/forward_warp.py:53: SyntaxWarning: "is" with a literal. Did you mean "=="?
#11 3.462   if(interpolation_mode is "Bilinear"):
#11 3.462 byte-compiling build/bdist.linux-x86_64/egg/Forward_Warp/__init__.py to __init__.cpython-310.pyc
#11 3.462 creating build/bdist.linux-x86_64/egg/EGG-INFO
#11 3.463 copying Forward_Warp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
#11 3.463 copying Forward_Warp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
#11 3.463 copying Forward_Warp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
#11 3.463 copying Forward_Warp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
#11 3.463 zip_safe flag not set; analyzing archive contents...
#11 3.464 creating dist
#11 3.464 creating 'dist/Forward_Warp-0.0.1-py3.10.egg' and adding 'build/bdist.linux-x86_64/egg' to it
#11 3.466 removing 'build/bdist.linux-x86_64/egg' (and everything under it)
#11 3.467 Processing Forward_Warp-0.0.1-py3.10.egg
#11 3.468 Copying Forward_Warp-0.0.1-py3.10.egg to /opt/miniconda/lib/python3.10/site-packages
#11 3.470 Adding Forward-Warp 0.0.1 to easy-install.pth file
#11 3.471
#11 3.471 Installed /opt/miniconda/lib/python3.10/site-packages/Forward_Warp-0.0.1-py3.10.egg
#11 3.472 Processing dependencies for Forward-Warp==0.0.1
#11 3.472 Finished processing dependencies for Forward-Warp==0.0.1
#11 DONE 3.5s

Solved some compilation warning when using PyTorch 1.7.1 with CUDA 10.2

Thanks for sharing the code, which is a good example of how to implement and access 4D tensors (of size [N, C, H, W]) in a low-level CUDA kernel (e.g., idx = blockIdx.x * blockDim.x + threadIdx.x).

I would like to share the following:

  1. Uncomment the macro definitions CHECK_CUDA, CHECK_CONTIGUOUS, and CHECK_INPUT. Otherwise you may get incorrect results when you run test/test.py, because tensors generated by e.g. im0 = torch.FloatTensor(im0).permute(0, 3, 1, 2) are not contiguous. Use im0 = torch.FloatTensor(im0).permute(0, 3, 1, 2).contiguous() instead.

  2. Compilation deprecation warnings; the fixes include:

    • a) In cuda/forward_warp_cuda.cpp file: change "#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), ..." to "#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), ...";
    • b) In cuda/forward_warp_cuda_kernel.cu file: change "AT_DISPATCH_FLOATING_TYPES(im0.type(), ..." to "AT_DISPATCH_FLOATING_TYPES(im0.scalar_type(), ...", and im0.data<scalar_t>() to im0.data_ptr<scalar_t>();

Now I can compile the CUDA code and get the correct results when running test/test.py.

unexpected result of "assert ... or 1"

This assert's result is unexpected.
https://github.com/lizhihao6/Forward-Warp/blob/master/Forward_Warp/forward_warp.py#L19

In [7]: a=2

In [8]: assert a is 0 or 1

In [9]: assert (a is 0 or 1)

In [10]: assert a is 0
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-10-84d2b10b4afa> in <module>
----> 1 assert a is 0

AssertionError:

In [11]: assert a == 0
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-11-4321e3980dbd> in <module>
----> 1 assert a == 0

AssertionError:

In [12]: assert a == 0 or 1

In [13]: assert(a == 0 or 1)

In [15]: assert a in (0 , 1)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-15-51fbe2680d83> in <module>
----> 1 assert a in (0 , 1)

AssertionError:

In [16]: assert a in (0, 1, 2)

So I will stick to using "assert x in (...)".
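
For reference, the surprising behaviour comes from Python's operator precedence, not from this repository specifically; a short sketch:

# "assert a is 0 or 1" parses as "assert (a is 0) or 1"; the literal 1 is
# always truthy, so the assertion can never fail.
a = 2
assert a is 0 or 1    # passes for any a
assert (a is 0) or 1  # identical parse, also passes
# The intended check is a membership test:
# assert a in (0, 1)  # would raise AssertionError for a == 2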

Example usage

A small error in the README's Usage section:
the parameter 'interpolation_mode' should be supplied during initialization, as in
fw = forward_warp(interpolation_mode="Nearest")

Grad of all-zero flow

In the simplest case, if the flow is a zero tensor, every source element is copied to the same position in the target tensor. If the flow changes slightly, the target tensor also changes, which means the flow gradient should be non-zero.
However, this test code shows otherwise:

from Forward_Warp import forward_warp
import torch

a = torch.randn(1,1,5,5)
flow = torch.zeros(1,5,5,2, requires_grad=True)
fwarp = forward_warp()
b = fwarp(a, flow)
b.sum().backward()
print(flow.grad)
# flow.grad is an all-zero tensor.

Any idea what's the issue? Thanks.

error during install

Thank you for the amazing repo! This function is absolutely a missing piece of PyTorch.

However, I encountered the following error during install:

Traceback (most recent call last):
  File "/home/username/anaconda3/envs/ai/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1515, in _run_ninja_build
    env=env)
  File "/home/username/anaconda3/envs/ai/lib/python3.7/subprocess.py", line 512, in run
    output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

Does anyone have any idea how to deal with it?

support flo file

How do I use test.py with optical flow info from a *.flo file instead of *.pkl?

Thanks,
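
A hypothetical helper, in case it is useful (it assumes the standard Middlebury .flo layout: a float32 magic value 202021.25, then int32 width and height, then interleaved (u, v) float32 pairs; the returned [1, H, W, 2] shape follows the flow layout used in the other issues and is not part of this repository):

import numpy as np
import torch

def read_flo(path):
    with open(path, "rb") as f:
        magic = np.fromfile(f, np.float32, count=1)[0]
        assert magic == 202021.25, "invalid .flo file"
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * w * h)
    flow = data.reshape(h, w, 2)                # (u, v) per pixel
    return torch.from_numpy(flow).unsqueeze(0)  # [1, H, W, 2]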

Questions about the usage of this project

Hi, I'm a newcomer to the optical flow field, and I am attracted by this excellent project. As far as I know, the grid_sample function in PyTorch handles backward flow warping: for example, given frame1 and frame2, I can use flow1->2 to warp information from frame2 to frame1. What about your code? Does forward flow warping mean I can use flow1->2 to warp information from frame1 to frame2? I'm looking forward to your reply, thank you.
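
A hedged sketch contrasting the two directions (the grid_sample part is standard PyTorch; the forward_warp call and the assumption that flow[..., 0] is the horizontal component follow the examples in the other issues):

import torch
import torch.nn.functional as F
from Forward_Warp import forward_warp

B, C, H, W = 1, 3, 32, 32
frame1 = torch.randn(B, C, H, W).cuda()
frame2 = torch.randn(B, C, H, W).cuda()
flow12 = torch.zeros(B, H, W, 2).cuda()  # flow from frame1 to frame2

# Backward warping: sample frame2 at positions shifted by flow1->2 to get a
# frame1-aligned result (grid_sample needs a grid normalized to [-1, 1]).
xs = torch.arange(W).view(1, 1, W).expand(B, H, W).float().cuda()
ys = torch.arange(H).view(1, H, 1).expand(B, H, W).float().cuda()
grid = torch.stack((xs, ys), dim=-1) + flow12
grid[..., 0] = 2.0 * grid[..., 0] / (W - 1) - 1.0
grid[..., 1] = 2.0 * grid[..., 1] / (H - 1) - 1.0
frame1_aligned = F.grid_sample(frame2, grid, align_corners=True)

# Forward warping: splat frame1 along the same flow1->2 to get a
# frame2-aligned result.
fw = forward_warp()
frame2_aligned = fw(frame1, flow12)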
