manuelfritsche / real-world-sr

[ICCVW 2019] PyTorch implementation of DSGAN and ESRGAN-FS from the paper "Frequency Separation for Real-World Super-Resolution". This code was the winning solution of the AIM challenge on Real-World Super-Resolution at ICCV 2019.

License: MIT License

Python 94.41% Shell 0.11% MATLAB 5.48%

real-world-sr's Issues

Files for DSGAN SDSR pretrained model broken

Thanks to the author for the great work. The compressed files of the pre-trained DSGAN SDSR models (DF2K_Gaussian.tar etc.) appear to be broken. Does anyone have an intact tar file? Many thanks if the author can re-upload them!

Trained Models

Hi Manuel,

I was unable to locate the trained models for DSGAN. Have you already shared the trained models, or do you plan to share them?

Thanks,
Touqeer

Will I need to retrain DSGAN for 2x SR?

I want to train ESRGAN-FS for super-resolution with a scale factor of 2. To generate the dataset, will I need to retrain DSGAN accordingly, or can I use the pretrained DSGAN you provide to translate, let's say, 2x downsampled images to the target domain? I must say, the results from your models are very good even for blind SR with random images. I was also wondering whether you have already run some tests for 2x and, if so, what kind of performance you saw.

Also, in dsgan/create_dataset.py, I understand that the for loop in lines 76-112 does what you explain in Section 4.1 (Dataset Generation) of your paper. But I fail to understand what the next for loop, in lines 114-141, does. I am assuming there's a typo in the comment at https://github.com/ManuelFritsche/real-world-sr/blob/4be7851da7decba5a005b1f0feee02a3c315b01e/dsgan/create_dataset.py#L137 but it would still be very helpful if you could explain what's happening there. I would be grateful :).
Best,
Samim

Issue in the DSGAN part about unsupervised training

Hello Manuel,

Recently I have been following your work on FSSR. While changing the structure of DSGAN, I ran into some doubts about the unsupervised training, as follows.

[image]
As the architecture depicted in the paper shows, the unsupervised part of DSGAN lies in the discriminator, which has to tell a generated LR image from a real-world LR image.

But when reading the code of train.py, I find that for any dataset other than aim2019, the train_set is defined as follows:
train_set = loader.TrainDataset(PATHS[opt.dataset][opt.artifacts]['hr']['train'], cropped=True, **vars(opt))

I also drew a rough pipeline of how the data is handled:
[image]

So, as I understand it, you haven't used the real-world LR images as disc_img in the backward pass, so it can't count as unsupervised training?

Is my understanding correct?

Furthermore, if I want to have the real-world LR images involved in the train dataset, can I just change the definition of the dataset to match the way the validation dataset is defined (outputting 3 images: the bicubic-downsampled image, the image downscaled by the generator, and the real LR image)?

By the way, I want to confirm that the output of utils.imresize(img) is the bicubic-downsampled version of img.
And what does the abbreviation disc (in disc_img) mean? (Just for personal interest.)

Thx

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 256, 1, 1]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I get the error above when I run dsgan/train.py with python train.py. Can anyone help me solve this problem? Thank you.
I have tried the fix suggested in some blog posts, setting inplace to False for ReLU and LeakyReLU, but the problem is not solved.
This is the model code:
```python
from torch import nn
import torch


class Generator(nn.Module):
    def __init__(self, n_res_blocks=8):
        super(Generator, self).__init__()
        self.block_input = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.PReLU()
        )
        self.res_blocks = nn.ModuleList([ResidualBlock(64) for _ in range(n_res_blocks)])
        self.block_output = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x):
        block = self.block_input(x)
        for res_block in self.res_blocks:
            block = res_block(block)
        block = self.block_output(block)
        return torch.sigmoid(block)


class Discriminator(nn.Module):
    def __init__(self, recursions=1, stride=1, kernel_size=5, gaussian=False, wgan=False, highpass=True):
        super(Discriminator, self).__init__()
        if highpass:
            self.filter = FilterHigh(recursions=recursions, stride=stride, kernel_size=kernel_size,
                                     include_pad=False, gaussian=gaussian)
        else:
            self.filter = None
        self.net = DiscriminatorBasic(n_input_channels=3)
        self.wgan = wgan

    def forward(self, x, y=None):
        if self.filter is not None:
            x = self.filter(x)
        x = self.net(x)
        if y is not None:
            x -= self.net(self.filter(y)).mean(0, keepdim=True)
        if not self.wgan:
            x = torch.sigmoid(x)
        return x


class DiscriminatorBasic(nn.Module):
    def __init__(self, n_input_channels=3):
        super(DiscriminatorBasic, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(n_input_channels, 64, kernel_size=5, padding=2),
            nn.LeakyReLU(0.2, inplace=False),

            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=False),

            nn.Conv2d(128, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=False),

            nn.Conv2d(256, 1, kernel_size=1)
        )

    def forward(self, x):
        return self.net(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        return x + residual


class GaussianFilter(nn.Module):
    def __init__(self, kernel_size=5, stride=1, padding=4):
        super(GaussianFilter, self).__init__()
        # initialize gaussian kernel
        mean = (kernel_size - 1) / 2.0
        variance = (kernel_size / 6.0) ** 2.0
        # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
        x_coord = torch.arange(kernel_size)
        x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
        y_grid = x_grid.t()
        xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

        # Calculate the 2-dimensional gaussian kernel
        gaussian_kernel = torch.exp(-torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance))

        # Make sure sum of values in gaussian kernel equals 1.
        gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

        # Reshape to 2d depthwise convolutional weight
        gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
        gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1)

        # create gaussian filter as convolutional layer
        self.gaussian_filter = nn.Conv2d(3, 3, kernel_size, stride=stride, padding=padding, groups=3, bias=False)
        self.gaussian_filter.weight.data = gaussian_kernel
        self.gaussian_filter.weight.requires_grad = False

    def forward(self, x):
        return self.gaussian_filter(x)


class FilterLow(nn.Module):
    def __init__(self, recursions=1, kernel_size=5, stride=1, padding=True, include_pad=True, gaussian=False):
        super(FilterLow, self).__init__()
        if padding:
            pad = int((kernel_size - 1) / 2)
        else:
            pad = 0
        if gaussian:
            self.filter = GaussianFilter(kernel_size=kernel_size, stride=stride, padding=pad)
        else:
            self.filter = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=pad,
                                       count_include_pad=include_pad)
        self.recursions = recursions

    def forward(self, img):
        for i in range(self.recursions):
            img = self.filter(img)
        return img


class FilterHigh(nn.Module):
    def __init__(self, recursions=1, kernel_size=5, stride=1, include_pad=True, normalize=True, gaussian=False):
        super(FilterHigh, self).__init__()
        self.filter_low = FilterLow(recursions=1, kernel_size=kernel_size, stride=stride, include_pad=include_pad,
                                    gaussian=gaussian)
        self.recursions = recursions
        self.normalize = normalize

    def forward(self, img):
        if self.recursions > 1:
            for i in range(self.recursions - 1):
                img = self.filter_low(img)
        img = img - self.filter_low(img)
        if self.normalize:
            return 0.5 + img * 0.5
        else:
            return img
```
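A note on the error above: the reported tensor shape [1, 256, 1, 1] matches the weight of the final nn.Conv2d(256, 1, kernel_size=1) in DiscriminatorBasic, which suggests that a discriminator parameter was modified between a forward pass and the backward pass that still needed it. The training loop is not shown in the post, so the following is only a sketch under the assumption that an optimizer step on the discriminator ran before the generator loss was backpropagated (a common trigger on PyTorch 1.5 and later). The function train_step, the toy nn.Linear models and the BCE loss are illustrative and are not the repository's code.

```python
import torch
from torch import nn


def train_step(gen, disc, opt_g, opt_d, real, noise):
    """One GAN step ordered so no parameter changes between a forward pass and its backward pass."""
    bce = nn.BCEWithLogitsLoss()
    fake = gen(noise)

    # Generator update: forward through D and backward BEFORE D's weights are stepped.
    opt_g.zero_grad()
    g_out = disc(fake)
    loss_g = bce(g_out, torch.ones_like(g_out))
    loss_g.backward()
    opt_g.step()

    # Discriminator update: fresh forward passes; detach() cuts the graph back to G.
    opt_d.zero_grad()  # also clears the D gradients accumulated by loss_g.backward()
    d_real = disc(real)
    d_fake = disc(fake.detach())
    loss_d = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))
    loss_d.backward()
    opt_d.step()
    return loss_g.item(), loss_d.item()


# Toy usage with stand-in models (hypothetical sizes).
gen, disc = nn.Linear(4, 4), nn.Linear(4, 1)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)
print(train_step(gen, disc, opt_g, opt_d, torch.randn(8, 4), torch.randn(8, 4)))
```

If the original loop instead steps the discriminator first and then calls backward on a generator loss computed from the earlier discriminator forward pass, recomputing that forward pass after the step is another way to avoid the version mismatch.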


no attribute PerceptualLoss

Hi Manuel, I was trying to follow your example to train DSGAN, but I am encountering the following error:

"module 'PerceptualSimilarity' has no attribute 'PerceptualLoss'"

It occurs when loss.py tries to initialize the PerceptualLoss. I have downloaded PerceptualSimilarity and placed it in the directory 'real-world-sr/dsgan/'. I was able to successfully run the examples provided by the PerceptualSimilarity author. Can you please advise?

Thanks,
Touqeer

As for PerceptualSimilarity

Hi, I read up on this and tried "from PerceptualSimilarity import models" to figure out the problem, but it showed:
ImportError: cannot import name 'models'

I would really appreciate a reply.
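For the two PerceptualSimilarity reports above, a quick way to narrow things down is to check which importable module, if any, actually exposes the name the code expects. The helper find_attribute below is a hypothetical diagnostic written for this note, and the candidate module paths are guesses, not a statement about the package's real layout.

```python
import importlib


def find_attribute(candidates, attr):
    """Return the names of importable modules in `candidates` that expose `attr`."""
    found = []
    for name in candidates:
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue  # this module (or its parent package) cannot be imported from here
        if hasattr(module, attr):
            found.append(name)
    return found


# Candidate locations are illustrative guesses; run this from the dsgan directory.
print(find_attribute(["PerceptualSimilarity", "PerceptualSimilarity.models"], "PerceptualLoss"))
```

An empty list means that none of the candidates both imports and provides the attribute, which points to a mismatch between the downloaded package version and the one loss.py was written against.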

LR image generation

Hi Fritsche,
I have a question about training DSGAN from scratch:
are the LR images generated directly by bicubic downsampling, and is cropping used?

Thanks

Where is the source images Z for Discriminator?

Hello.
I am trying to train a generator using the provided code with the DIV2K dataset.
However, I don't know where I need to input the source images Z for the discriminator.
If some part of the code deals with this, could you tell me where it is?
Thank you very much.

Trouble training ESRGAN with jpeg artifact images.

Hello! Thank you for posting your code on GitHub. I successfully recreated the first part of your method (training DSGAN). Specifically I trained for jpeg artifacts and generated a paired dataset for the SR network (ESRGAN-FS). The generated dataset looks convincing, DSGAN was able to generate jpeg artifacts that look real. Then I trained ESRGAN-FS with your given codes/options/jpeg/train_TDSR.yml config file (all settings unchanged and placed the RRDB_ESRGAN_x4.pth pretrained model in the appropriate folder).

However, I am getting very weird results on the validation set. The images look cartoonish and the details are missing. This is for the TDSR case (50000 iterations), where I am trying to get clean HR images from LR images with JPEG artifacts generated by DSGAN.
[image: ESRGAN output]

Now, the weird thing is that I also checked supervised training, where the LR images are real images with JPEG artifacts (not generated), and still got the same cartoonish output. So the problem is with ESRGAN-FS, not with the DSGAN-generated dataset. In addition, I also trained the original ESRGAN (from the BasicSR repo) and got the same problem.

Can you help me with this? Maybe the issue is with the training configs (learning rate, loss weights) in the options files?

Question on TrainDataset in dsgan

Hi Manuel Fritsche,
Thanks for open-sourcing the code. I have read your paper, and Figure 3 shows the structure used to generate LR images with natural characteristics. We can clearly see that the input of Gd is the bicubically downsampled LR image, and the inputs of Dd are the output of Gd and the source domain image. But when I look at dsgan/train.py in detail, I find that only source domain images are used during the training of DSGAN. This is not consistent with Figure 3 of the paper. Can you explain the reason?

There is no if __name__ == '__main__': in train.py

When I run "python train.py --dataset dataset --artifacts artifacts" I get the problem below.
I hope the author can help me.
Thank you!

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.
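The message above is raised by Python's multiprocessing module when worker processes are started with the spawn method (the default on Windows) and the script's top-level code is not protected by a main guard. Presumably the workers here come from a data loader with several worker processes, which is an assumption since the full traceback is not shown. A minimal stdlib sketch of the idiom, with hypothetical function names:

```python
import multiprocessing as mp


def square(x):
    # Work done per item; must be defined at module level so workers can import it.
    return x * x


def main(num_workers=0):
    data = list(range(5))
    if num_workers == 0:
        # Same convention as a data loader with no workers: run in the main process.
        return [square(x) for x in data]
    ctx = mp.get_context("spawn")  # spawn re-imports this module in every worker
    with ctx.Pool(processes=num_workers) as pool:
        return pool.map(square, data)


if __name__ == "__main__":
    # Everything that starts training or worker processes belongs under this guard.
    print(main(num_workers=0))  # raise to e.g. 2 to use worker processes
```

Applied to train.py, the same pattern would mean moving the training code into a function and calling it only under the guard; setting the number of loader workers to zero is the other usual workaround.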

Loss Curves for Training

Hi Manuel,

Would it be possible for you to provide sample loss curves that one can load in Tensorboard and use as guidance?

I am trying to train DSGAN on my own data in the SDSR context, where I do not add any additional degradation to the original frames and want the LR frames to have the same characteristics as the source. My understanding is that the LR frames generated by DSGAN, once trained, should be crisper than the LR frames generated through bicubic downsampling. Am I correct in my understanding?

Currently I see the texture and perceptual losses go up, while the color loss and the overall loss go down. I am unclear whether this is expected behavior, which is why I am requesting sample loss curves from your training runs.

Thanks,
Touqeer

_pickle.PicklingError: Can't pickle <function <lambda> at 0x7f308810e2f0>: attribute lookup <lambda> on __main__ failed

_pickle.PicklingError: Can't pickle <function <lambda> at 0x7f308810e2f0>: attribute lookup <lambda> on __main__ failed

The result of scheduler_g.state_dict() cannot be saved with torch.save(state_dict, path):

state_dict = {
    'epoch': epoch,
    'iteration': iteration,
    'model_g_state_dict': model_g.state_dict(),
    'models_d_state_dict': model_d.state_dict(),
    'optimizer_g_state_dict': optimizer_g.state_dict(),
    'optimizer_d_state_dict': optimizer_d.state_dict(),
    'scheduler_g_state_dict': scheduler_g.state_dict(),
    'scheduler_d_state_dict': scheduler_d.state_dict(),
}
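The traceback indicates that something in the saved state holds a lambda, which pickle cannot serialize because it has no importable name. Assuming the lambda lives inside the scheduler state (the scheduler construction is not shown, so this is a guess), one possible workaround is to drop the non-picklable entries before saving and to re-create the schedule function when loading. The helper picklable_items and the key names in the toy dictionary are illustrative, not the library's API.

```python
import pickle


def picklable_items(state):
    """Return a copy of `state` containing only the entries that pickle can serialize."""
    kept = {}
    for key, value in state.items():
        try:
            pickle.dumps(value)
        except (pickle.PicklingError, AttributeError, TypeError):
            continue  # e.g. a lambda, which has no importable name
        kept[key] = value
    return kept


# Toy stand-in for a scheduler state that carries a lambda (hypothetical keys).
scheduler_state = {
    "last_epoch": 3,
    "base_lrs": [1e-4],
    "lr_lambdas": [lambda epoch: 0.5 ** epoch],
}
print(sorted(picklable_items(scheduler_state)))
```

Defining the schedule as a named module-level function instead of a lambda avoids the problem at its source, since pickle can then store it by reference.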
