
deep-generalized-unfolding-networks-for-image-restoration's Introduction

Deep Generalized Unfolding Networks for Image Restoration (CVPR 2022)

Chong Mou, Qian Wang, Jian Zhang

Paper

Abstract: Deep neural networks (DNN) have achieved great success in image restoration. However, most DNN methods are designed as a black box, lacking transparency and interpretability. Although some methods are proposed to combine traditional optimization algorithms with DNN, they usually demand pre-defined degradation processes or handcrafted assumptions, making it difficult to deal with complex and real-world applications. In this paper, we propose a Deep Generalized Unfolding Network (DGUNet) for image restoration. Concretely, without loss of interpretability, we integrate a gradient estimation strategy into the gradient descent step of the Proximal Gradient Descent (PGD) algorithm, driving it to deal with complex and real-world image degradation. In addition, we design inter-stage information pathways across proximal mapping in different PGD iterations to rectify the intrinsic information loss in most deep unfolding networks (DUN) through a multi-scale and spatial-adaptive way. By integrating the flexible gradient descent and informative proximal mapping, we unfold the iterative PGD algorithm into a trainable DNN. Extensive experiments on various image restoration tasks demonstrate the superiority of our method in terms of state-of-the-art performance, interpretability, and generalizability.
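For orientation, the PGD iteration that the abstract refers to can be written in its textbook form as below. This is a sketch in standard notation, which may differ from the paper's: the data term $\tfrac{1}{2}\lVert \mathbf{A}\mathbf{x}-\mathbf{y}\rVert_2^2$, the regularizer $\psi$, the step size $\rho$ and the weight $\lambda$ are assumptions of this sketch.

```latex
% Standard proximal gradient descent for
%   min_x  (1/2) ||A x - y||_2^2 + \lambda \psi(x)
\begin{aligned}
\mathbf{z}^{k} &= \hat{\mathbf{x}}^{k-1}
  - \rho\,\mathbf{A}^{\top}\!\left(\mathbf{A}\hat{\mathbf{x}}^{k-1} - \mathbf{y}\right)
  && \text{(gradient descent step)} \\
\hat{\mathbf{x}}^{k} &= \operatorname{prox}_{\rho\lambda\psi}\!\left(\mathbf{z}^{k}\right)
  = \arg\min_{\mathbf{x}} \tfrac{1}{2}\lVert \mathbf{x} - \mathbf{z}^{k} \rVert_2^2
  + \rho\lambda\,\psi(\mathbf{x})
  && \text{(proximal mapping)}
\end{aligned}
```

According to the abstract, DGUNet estimates the gradient in the first step when the degradation is not known in advance, and adds inter-stage information pathways across the proximal mapping in the second step.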

🔥 Network Architecture

Network

🎨 Applications

🚩Deblurring🚩

blur

🚩Deraining🚩

rain rain

🚩Denoising🚩

noise

🚩Compressive Sensing🚩

noise

🔧 Installation

The model is built with PyTorch 1.1.0 and tested on Ubuntu 16.04 (Python 3.7, CUDA 9.0, cuDNN 7.5). It was trained on 2 NVIDIA V100 GPUs.

To install, follow these instructions:

conda create -n pytorch1 python=3.7
conda activate pytorch1
conda install pytorch=1.1 torchvision=0.3 cudatoolkit=9.0 -c pytorch
pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm

Install the warmup scheduler:

cd pytorch-gradual-warmup-lr; python setup.py install; cd ..
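After installation, a quick way to confirm that the environment is complete is to check that each package can be found. The following is a minimal sketch, not part of the repository: `missing_packages` and `REQUIRED` are hypothetical names, and the list uses the usual import names for the packages installed above (for example `skimage` for scikit-image and `cv2` for opencv-python).

```python
import importlib.util

# Import names for the packages installed above
# (scikit-image is imported as skimage, opencv-python as cv2).
REQUIRED = ["torch", "torchvision", "matplotlib", "skimage", "cv2",
            "yacs", "joblib", "natsort", "h5py", "tqdm"]


def missing_packages(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
print("missing:", missing if missing else "none")
```

If the list is not empty, rerun the corresponding conda or pip command before training or testing.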

💻 Training and Evaluation

Training and testing code for deblurring, deraining, denoising, and compressive sensing is provided in the respective directories.

🏰 Model Zoo

For Deblurring, Deraining, Denoising

Please download checkpoints from Google Drive.

For Compressive Sensing

Please download checkpoints from Google Drive.

📑 Citation

If you use DGUNet, please consider citing:

@inproceedings{Mou2022DGUNet,
    title={Deep Generalized Unfolding Networks for Image Restoration},
    author={Chong Mou and Qian Wang and Jian Zhang},
    booktitle={CVPR},
    year={2022}
}

📧 Contact

If you have any questions, please email [email protected].

🤗 Acknowledgements

This code is built on MPRNet (PyTorch). We thank the authors of MPRNet for sharing their code.

deep-generalized-unfolding-networks-for-image-restoration's People

Contributors

mc-e


deep-generalized-unfolding-networks-for-image-restoration's Issues

The model you provided failed to run

When I tried to reproduce the results of the paper with "python test_DND.py --save_image", the following error occurred.
[Image: screenshot of the error]

Could you confirm whether the model you provided was trained with torch 1.1? If not, please provide the correct environment configuration.

Issue about trainable measurement matrix A in compressed sensing mode

Hi there @MC-E ,

Thank you for your great input and sharing the code! I have a question about the compressed sensing case:

  1. As you mentioned in the paper: "Note that in the task of compressive sensing, the degradation matrix A is exactly known, i.e., the sampling matrix Φ. Thus, we directly use Φ to calculate the gradient.". However, in the code, you set A to be trainable parameters instead:

     PhiTPhi = torch.mm(torch.transpose(self.Phi, 0, 1), self.Phi)  # torch.mm(Phix, Phi)
     Phix = torch.mm(img.view(b,-1), torch.transpose(self.Phi, 0, 1))  # compression result
     PhiTb = torch.mm(Phix,self.Phi)
     # compute r_0
     x_0=PhiTb.view(b,-1)
     x = x_0 - self.r0 * torch.mm(x_0, PhiTPhi)
     r_0 = x + self.r0 * PhiTb
     r_0=r_0.view(b,c,w,h)
    

https://github.com/MC-E/Deep-Generalized-Unfolding-Networks-for-Image-Restoration/blob/bae845c2612d0df56a479020d59896441168d07a/Compressive-Sensing/DGUNet.py#L384C1-L391C30

where self.Phi (referred to as A) is learnable during training. This confuses me, because in compressed sensing we assume that the only information we have is y and A, and that we have no access to the raw image X_0. But since A is a learnable parameter here, y is essentially a linear transformation of the real X_0, which means that all the information in X_0 is available as input to the model. In the end, the model is learning \hat{X} (output) given the real image X_0 (input), which is somewhat equivalent to recovering X_0 given X_0.

If instead you assume A is unknown, then the process of taking the measurement y (the computation of Phix in the code) should not involve the learnable parameter A.

  2. What is more, the input to the model at test time is the real image (say X_0):

         batch_x = torch.from_numpy(Img_output)
         batch_x = batch_x.type(torch.FloatTensor)
         batch_x = batch_x.to(device)
       # Phix = torch.mm(batch_x, torch.transpose(Phi, 0, 1))  # compression result
       # PhixPhiT = torch.mm(Phix, Phi)
         batch_x = batch_x.view(batch_x.shape[0], 1, args.patch_size, args.patch_size)
         x_output = model(batch_x)[0]  # torch.mm(batch_x,
    

https://github.com/MC-E/Deep-Generalized-Unfolding-Networks-for-Image-Restoration/blob/bae845c2612d0df56a479020d59896441168d07a/Compressive-Sensing/train.py#L231C1-L238C62

and the measurement y is obtained as y = A X_0, where A is the trainable parameter. I do not understand this setting: at test time we assume that the only information we have is y and A (if we know the degradation model), but here the input to the model is the real raw test image.

Please correct me if I have misunderstood anything, and I apologize in advance if I missed something that is already explained clearly in the paper or the code. Thank you so much, and I look forward to your reply!
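To make the first snippet above concrete: expanding it shows that r_0 is one gradient descent step on 0.5 * ||Phi x - y||^2 starting from x_0 = Phi^T y. The sketch below writes that step with the measurement taken outside the reconstruction, so that only y and Phi are passed on. It is an illustration under assumptions, not the repository's code: `measure` and `gradient_step` are hypothetical helpers, Phi is a fixed random matrix, and the sizes and step size are arbitrary.

```python
import torch


def measure(x, phi):
    """Sampling step y = Phi x, for row-vector batches of shape (batch, n)."""
    return x @ phi.t()


def gradient_step(x, y, phi, step):
    """One gradient step on 0.5 * ||Phi x - y||^2: x - step * Phi^T (Phi x - y)."""
    return x - step * ((x @ phi.t() - y) @ phi)


torch.manual_seed(0)
n, m, batch = 64, 16, 4                # signal size, measurements, batch (arbitrary)
phi = torch.randn(m, n) / m ** 0.5     # fixed sampling matrix (assumption: not trained)
x_true = torch.randn(batch, n)         # stands in for the flattened image patches

y = measure(x_true, phi)               # the only data passed to the reconstruction
x_init = y @ phi                       # Phi^T y, the same role as PhiTb in the snippet
x_next = gradient_step(x_init, y, phi, step=0.1)
```

With `step` equal to `self.r0`, `x_next` matches the `r_0` computed in the quoted snippet; the difference discussed in this issue is only where `Phix` is computed and whether `Phi` is trained.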

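Training procedure

Hello, and thank you for providing the source code. Why is training split into two steps, train and train_deblock? What is the purpose of this? I look forward to your answer, thank you.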

About Loss

First, this is really awesome work!
When I run train.py in the Deblurring directory, I get an error:

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
It seems that you use numpy.sum() to compute the loss after CharbonnierLoss() and EdgeLoss(). As far as I know, if we call loss.backward(), we should use the tools provided by PyTorch (I am not sure whether this is true).
The code you provided is below:

    loss_char = np.sum([criterion_char(restored[j],target) for j in range(len(restored))])
    loss_edge = np.sum([criterion_edge(restored[j], target) for j in range(len(restored))])
    loss = (loss_char) + (0.05*loss_edge)

When I replace the above with the code below, the PSNR does not change during training:

    loss_char = torch.tensor([criterion_char(restored[j], target) for j in range(len(restored))],requires_grad=True).sum()
    loss_edge = torch.tensor([criterion_edge(restored[j], target) for j in range(len(restored))],requires_grad=True).sum()
    loss = (loss_char) + (0.05*loss_edge)

Could you help me train it correctly? Looking forward to your reply!
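One likely explanation for the unchanged PSNR is that torch.tensor(...) builds a new leaf tensor from the loss values, so the result is no longer connected to the network's parameters and backward() has nothing to update. A possible fix is to sum the per-stage losses with torch.stack(...).sum() (or Python's built-in sum), which keeps every term in the autograd graph. The sketch below is a minimal illustration, not the repository's code: `charbonnier` is a stand-in loss whose exact form may differ from the repo's CharbonnierLoss, and `weight` stands in for the model parameters.

```python
import torch


def charbonnier(pred, target, eps=1e-3):
    """Stand-in Charbonnier-style loss (assumption: not the repo's exact class)."""
    diff = pred - target
    return torch.mean(torch.sqrt(diff * diff + eps * eps))


def multi_stage_loss(restored, target):
    """Sum per-stage losses while keeping them attached to the autograd graph."""
    return torch.stack([charbonnier(r, target) for r in restored]).sum()


def detached_loss(restored, target):
    """Rebuild the loss from plain numbers: a new leaf tensor, cut off from the model."""
    values = [charbonnier(r, target).item() for r in restored]
    return torch.tensor(values, requires_grad=True).sum()


weight = torch.ones(1, requires_grad=True)           # stands in for model parameters
target = torch.zeros(2, 3, 8, 8)
restored = [weight * torch.full((2, 3, 8, 8), 0.5) for _ in range(3)]

multi_stage_loss(restored, target).backward()
grad_attached = weight.grad.clone()                  # non-zero: gradients reach the weight

weight.grad = None
detached_loss(restored, target).backward()
grad_detached = weight.grad                          # None: nothing flows back
```

The same pattern would apply to the edge loss term, with the weighted sum formed from the two resulting tensors as in the original line.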

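The compressive sensing dataset differs from the paper

Hello!

  • Your paper states that only BSD400 was used as the training set for the compressive sensing task:

For this application, we choose the widely used BSD400 dataset [43] as the training data and evaluate each method on Set11 [35] and BSD68 [43] test sets.

  • However, the training set provided with your code is actually a mixed dataset of DIV2K800 + BSD400, 1200 images in total.
  • So were the PSNR results of all the compared methods in your paper also obtained by retraining on the mixed training set?
  • Did you retrain the other compared methods on the mixed dataset?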

The training set for the compressive sensing task

Hi, thanks for your efforts. I reproduced your work with the BSD400 dataset for the compressive sensing task. However, after several days of training, the results are still far from your reported results. Is there any trick to the training? Thanks very much! (The model has many parameters, while the training images are few.)

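Training resources

May I ask what memory configuration the two V100s you used have? Also, for the denoising task for example, roughly how much GPU memory does training require?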
