ms_ssim_pytorch's Introduction

ms_ssim_pytorch

The code was modified from https://github.com/VainF/pytorch-msssim.
Parts of the code have been modified to make it faster, use less VRAM, and be compatible with pytorch jit.

The dynamic channel version can be found here: https://github.com/One-sixth/ms_ssim_pytorch/tree/dynamic_channel_num.
It is more convenient to use but has a small performance loss.

Thanks to vegetable09 for finding and fixing a bug that caused nan gradients in the ms_ssim backward pass. #3

If you are using pytorch 1.2, be careful not to create and destroy this jit module inside the training loop, because doing so may leak memory (other jit modules may behave the same way). I have tested that pytorch 1.6 does not have this problem. #4

I studied the ssim.py of the piqa library, which made my implementation of ssim and ms-ssim a little faster than before.

Speed up. Only tested on GPU.

losser1 is https://github.com/lizhengwei1992/MS_SSIM_pytorch/blob/master/loss.py 268fc76
losser2 is https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 881d210
losser3 is https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py 5caf547
losser4 is https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py 1c2f14a
losser5 is https://github.com/francois-rozet/piqa/blob/master/piqa/ssim.py abaf398
https://github.com/francois-rozet/piqa/blob/master/piqa/utils.py abaf398

In pytorch 1.7.1
My test environment: i7-8750H GTX1070-8G

In pytorch 1.1 and 1.2
My test environment: i7-6700HQ GTX970M-3G

SSIM

Test output

pytorch 1.7.1

Performance Testing SSIM

testing losser2
cuda time 39963.15625
perf_counter time 35.9110169

testing losser3
cuda time 17141.841796875
perf_counter time 17.124456199999997

testing losser4
cuda time 13205.0322265625
perf_counter time 10.477991699999997

testing losser5
cuda time 13142.8232421875
perf_counter time 11.079514100000011

pytorch 1.2

Performance Testing SSIM

testing losser2
cuda time 89290.7734375
perf_counter time 87.1042247

testing losser3
cuda time 36153.64453125
perf_counter time 36.09167939999999

testing losser4
cuda time 31085.455078125
perf_counter time 29.80807200000001

pytorch 1.1

Performance Testing SSIM

testing losser2
cuda time 88990.0703125
perf_counter time 86.80163019999999

testing losser3
cuda time 36119.06640625
perf_counter time 36.057978399999996

testing losser4
cuda time 34708.8359375
perf_counter time 33.916086199999995

MS-SSIM

Test output

pytorch 1.7.1

Performance Testing MS_SSIM

testing losser1
cuda time 60403.59765625
perf_counter time 60.351266200000005

testing losser3
cuda time 26321.48828125
perf_counter time 26.30165939999999

testing losser4
cuda time 24471.6875
perf_counter time 24.45189119999999

testing losser5
cuda time 23153.962890625
perf_counter time 23.135575399999993

pytorch 1.2

Performance Testing MS_SSIM

testing losser1
cuda time 134158.84375
perf_counter time 134.0433756

testing losser3
cuda time 62143.4140625
perf_counter time 62.103911400000015

testing losser4
cuda time 46854.25390625
perf_counter time 46.81785239999999

pytorch 1.1

Performance Testing MS_SSIM

testing losser1
cuda time 134115.96875
perf_counter time 134.0006031

testing losser3
cuda time 61760.56640625
perf_counter time 61.71994470000001

testing losser4
cuda time 52888.03125
perf_counter time 52.848280500000016
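
To make the timings above easier to compare, here is a small sketch that turns the reported pytorch 1.7.1 cuda times into speedup ratios relative to the slowest implementation. The numbers are copied from the test output above; the helper name `speedups` and the dictionary names are only illustrative.

```python
# cuda times copied from the pytorch 1.7.1 test output above
ssim_cuda_time = {
    "losser2": 39963.15625,
    "losser3": 17141.841796875,
    "losser4": 13205.0322265625,
    "losser5": 13142.8232421875,
}
ms_ssim_cuda_time = {
    "losser1": 60403.59765625,
    "losser3": 26321.48828125,
    "losser4": 24471.6875,
    "losser5": 23153.962890625,
}


def speedups(times, baseline):
    """Return how many times faster each entry is than the baseline entry."""
    base = times[baseline]
    return {name: base / t for name, t in times.items()}


ssim_speedup = speedups(ssim_cuda_time, "losser2")
ms_ssim_speedup = speedups(ms_ssim_cuda_time, "losser1")

for name, ratio in ssim_speedup.items():
    print(f"SSIM    {name}: {ratio:.2f}x")
for name, ratio in ms_ssim_speedup.items():
    print(f"MS-SSIM {name}: {ratio:.2f}x")
```

On these figures, losser4 is roughly 3x faster than losser2 for SSIM and roughly 2.5x faster than losser1 for MS-SSIM.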

Test the speed yourself

  1. cd ms_ssim_pytorch/_test_speed

  2. python test_ssim_speed.py
     or
     python test_ms_ssim_speed.py

Other things

Added the parameter use_padding.
When set to True, the gaussian_filter behaves the same as in https://github.com/Po-Hsun-Su/pytorch-ssim.
This parameter is mainly intended for MS-SSIM, because MS-SSIM downsamples the input repeatedly.
When the input image is smaller than 176x176, this parameter must be set to True for MS-SSIM to work properly (with the default weight and level parameters).
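
The 176x176 limit can be reproduced with a little arithmetic. The sketch below assumes a Gaussian window of size 11, 5 scales, and a halving of each side (rounded down) between scales; these defaults are assumptions, not values stated in this README, but they do give 11 * 2**4 = 176. The function names are hypothetical.

```python
def min_input_side(window_size=11, levels=5):
    """Smallest side length that still leaves `window_size` pixels at the coarsest scale."""
    return window_size * 2 ** (levels - 1)


def coarsest_side(side, levels=5):
    """Side length after (levels - 1) halvings, rounding down each time."""
    for _ in range(levels - 1):
        side //= 2
    return side


def needs_padding(height, width, window_size=11, levels=5):
    """True if an unpadded window would not fit at the coarsest scale."""
    smallest = min(coarsest_side(height, levels), coarsest_side(width, levels))
    return smallest < window_size


print(min_input_side())           # smallest side that works without padding
print(needs_padding(256, 256))    # large input: padding not required
print(needs_padding(128, 128))    # small input: set use_padding=True
```

Under these assumptions a 128x128 input shrinks to 8x8 at the coarsest scale, which is smaller than the 11-pixel window, so use_padding=True would be needed.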

Requirements

Pytorch >= 1.1

If you want to run the animated test, you also need to install some additional packages:

pip install imageio imageio-ffmpeg opencv-python

Test code with animation

The test code is included in the ssim.py file; run the file directly to start the test.

  1. git clone https://github.com/One-sixth/ms_ssim_pytorch
  2. cd ms_ssim_pytorch
  3. python ssim.py

Code Example.

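import torch
import ssim


# Random batch of 5 RGB images, 256x256, with integer pixel values in [0, 255]
# (the upper bound of torch.randint is exclusive, hence 256).
im = torch.randint(0, 256, (5, 3, 256, 256), dtype=torch.float, device='cuda')
img1 = im / 255      # scale to [0, 1] to match data_range=1.
img2 = img1 * 0.5    # darker copy to compare against

losser = ssim.SSIM(data_range=1., channel=3).cuda()
loss = losser(img1, img2).mean()     # mean() reduces the result to a scalar

losser2 = ssim.MS_SSIM(data_range=1., channel=3).cuda()
loss2 = losser2(img1, img2).mean()

# Both values are similarity scores (higher means more similar).
print(loss.item())
print(loss2.item())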

Animation

The GIFs are fairly large, so loading may take some time.
Alternatively, you can download the mkv video files directly; they are smaller and play more smoothly.
https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim_test.mkv
https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ms_ssim_test.mkv

SSIM
[animation: ssim]

MS-SSIM
[animation: ms-ssim]

References

https://github.com/VainF/pytorch-msssim
https://github.com/Po-Hsun-Su/pytorch-ssim
https://github.com/lizhengwei1992/MS_SSIM_pytorch
https://github.com/francois-rozet/piqa

ms_ssim_pytorch's People

Contributors

one-sixth

ms_ssim_pytorch's Issues

try clamp_min(0.001) to avoid nan when backward() in early step

Thanks for your great effort to speed up MS-SSIM.

I got nan in backward() in very early iterations of training. This is because, in the backward pass,

the gradient of cs**a is a * cs**(a-1), i.e. a / cs**(1-a).

In the early steps cs is very small, so for a < 1 this gradient becomes inf.

I noticed that you use clamp_min(0.) in msssim.py. I think replacing 0 with 0.00001 will fix the nan problem in early training.


        ssim_val = ssim_val.clamp_min(0.0001) # keeps a * x**(a-1) finite when x is very close to zero
        cs = cs.clamp_min(0.0001)

Please correct me if I am wrong.
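
The arithmetic behind this report can be checked without PyTorch. The following is a minimal pure-Python sketch, assuming an exponent a between 0 and 1 (0.1333 is used purely as an illustrative value) and a clamp floor of 1e-4 as in the snippet above. The function names are hypothetical, and the clamp is modelled as passing zero gradient for inputs below the floor.

```python
import math


def pow_grad(x, a):
    """Analytic derivative of x**a with respect to x, for 0 < a < 1 and x >= 0."""
    if x == 0.0:
        return math.inf          # a * 0**(a - 1) diverges when a < 1
    return a * x ** (a - 1.0)


def clamped_pow_grad(x, a, floor=1e-4):
    """Derivative of max(x, floor)**a with respect to x."""
    if x < floor:
        return 0.0               # clamped inputs receive no gradient
    return pow_grad(x, a)


a = 0.1333  # illustrative exponent, assumed to be less than 1

for x in (0.0, 1e-12, 1e-4, 0.5):
    print(x, pow_grad(x, a), clamped_pow_grad(x, a))

# An infinite local gradient multiplied by a zero upstream gradient is nan,
# which is one way the nan can appear during backward().
print(math.inf * 0.0)
```

With the floor in place the gradient is either zero (below the floor) or bounded by its value at the floor, so no inf or nan is produced.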

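Memory is easily exhausted during actual runs

Thank you for your outstanding contribution; the speed really is much better. However, I ran into a problem in practice. I inserted your SSIM loss into my network, and every time, partway through the run (it starts fine), memory fills up completely and the process crashes. This happens both on a server and on my own computer. It still occurs after changing various dataloader parameters. After I switched the SSIM loss to loss3 from your code, the problem went away. I do not know much about how your code runs internally, but I did hit this problem in practice, so I am reporting it to you. I do not know what the cause is.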

Add dynamic channel expansion.

Hello. I have been looking at the current implementation and I have found that it uses .repeat() on CPU.

However, I would like to propose using .expand() inside the ssim() function. This has several advantages.

First, this will significantly reduce the amount of data passing from CPU to GPU, which is a major bottleneck in CUDA programming. .repeat() copies data. .expand() does not copy data. So there is not much overhead in using .expand() after transferring the kernel to GPU.

Second, this will allow one to remove the 'channel' parameter from SSIM(), which will remove an unnecessary parameter.
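
The difference between the two calls can be seen in a short CPU-only sketch. It assumes a 1-D Gaussian kernel of size 11 with sigma 1.5 stored in shape (1, 1, 1, 11); the helper `gaussian_kernel_1d` and that layout are illustrative assumptions, not the repository's actual code.

```python
import torch


def gaussian_kernel_1d(window_size=11, sigma=1.5):
    """Normalized 1-D Gaussian kernel (hypothetical helper for illustration)."""
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    return g / g.sum()


channel = 3
kernel = gaussian_kernel_1d().reshape(1, 1, 1, -1)   # shape (1, 1, 1, 11)

repeated = kernel.repeat(channel, 1, 1, 1)    # allocates new memory and copies the data
expanded = kernel.expand(channel, 1, 1, -1)   # view with stride 0 on dim 0, no copy

print(repeated.shape, expanded.shape)                 # same logical shape
print(torch.equal(repeated, expanded))                # same values
print(expanded.stride())                              # first stride is 0
print(expanded.data_ptr() == kernel.data_ptr())       # shares storage with kernel
```

Because the expanded tensor is a view, the channel count can be taken from the input at call time instead of being fixed when the module is built. Some operations may still make a contiguous copy internally, so the actual saving should be measured rather than assumed.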
