SRGAN_Wasserstein

Applying Wasserstein GAN to SRGAN, a GAN-based super-resolution algorithm.

This repo is forked from @zsdonghao's tensorlayer/srgan repo. Starting from that original repo, I changed some code to apply the Wasserstein loss, which makes the training procedure more stable. Thanks again to @zsdonghao for his great reimplementation.

SRGAN Architecture

TensorFlow Implementation of "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"

Wasserstein GAN

When SRGAN was first proposed in 2016, Wasserstein GAN (WGAN, 2017) did not exist yet. WGAN uses the Wasserstein distance to measure the difference between two data distributions. With the original GAN training, it is hard to know when to stop training the discriminator or the generator to get a good result. With the Wasserstein loss, the result gets better as the loss decreases. So we use WGAN here. We are not going to explain the mathematical details of WGAN, but only list the following steps to apply it.

  • Remove the sigmoid activation from the last layer of the discriminator. (model.py, lines 218-219)
  • Don't take the logarithm in the discriminator and generator losses. (main.py, lines 105-108)
  • Clip the discriminator weights to some constant range [-c, c]. (main.py, line 136)
  • Don't use momentum-based optimizers such as Adam or momentum SGD; RMSprop or plain SGD works better. (main.py, lines 132-133)

The steps above come from an excellent article [4], written in Chinese, whose author explains WGAN in a very straightforward way.
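The four steps can be sketched in plain Python as follows. This is only an illustration of the idea, not the code in main.py or model.py: the function names are hypothetical, the scores stand for raw discriminator outputs (no sigmoid), the sign convention is one common choice, and the values of `c`, `lr`, and `decay` are illustrative assumptions.

```python
import math


def critic_loss(real_scores, fake_scores):
    """Discriminator loss on raw scores: no sigmoid and no logarithm."""
    mean_real = sum(real_scores) / len(real_scores)
    mean_fake = sum(fake_scores) / len(fake_scores)
    # Minimizing this pushes real scores up and fake scores down.
    return mean_fake - mean_real


def generator_loss(fake_scores):
    """Generator loss: raise the discriminator's score on generated images."""
    return -sum(fake_scores) / len(fake_scores)


def clip_weights(weights, c=0.01):
    """Clamp every weight into the constant range [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]


def rmsprop_step(weights, grads, cache, lr=1e-4, decay=0.9, eps=1e-8):
    """One RMSprop update (no momentum term), returning new weights and cache."""
    new_cache = [decay * s + (1 - decay) * g * g for s, g in zip(cache, grads)]
    new_weights = [
        w - lr * g / (math.sqrt(s) + eps)
        for w, g, s in zip(weights, grads, new_cache)
    ]
    return new_weights, new_cache


# One discriminator update on toy numbers: optimizer step first, then clipping.
weights = [0.5, -0.2, 0.005]
grads = [1.0, -1.0, 0.0]
weights, cache = rmsprop_step(weights, grads, cache=[0.0, 0.0, 0.0])
weights = clip_weights(weights, c=0.01)
print(critic_loss([1.0, 3.0], [0.0, 1.0]), generator_loss([0.0, 1.0]), weights)
```

Clipping is applied after every discriminator update, so the weights never leave [-c, c] between steps.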

Loss curve and Result

Prepare Data and Pre-trained VGG

    1. Download the pretrained VGG19 model from here, as tutorial_vgg19.py shows.
    2. Prepare high-resolution images for training.
    • In this experiment, I used images from the DIV2K bicubic downscaling x4 competition, so the hyperparameters in config.py (such as the number of epochs) are selected based on that dataset. If you switch to a larger dataset, you can reduce the number of epochs.
    • If you don't want to use the DIV2K dataset, you can also use Yahoo MirFlickr25k; simply download it with train_hr_imgs = tl.files.load_flickr25k_dataset(tag=None) in main.py.
    • If you want to use your own images, set the path to your image folder via config.TRAIN.hr_img_path in config.py.
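Before training on your own images, it can help to check that the configured folder actually contains image files. The sketch below is a hypothetical helper, not part of this repo: `list_hr_images` and `IMAGE_SUFFIXES` are made-up names, and a `SimpleNamespace` stands in for the real config object in config.py so the snippet runs without extra packages.

```python
from pathlib import Path
from types import SimpleNamespace

# Stand-in for the config object defined in config.py (assumption).
config = SimpleNamespace(TRAIN=SimpleNamespace(hr_img_path="your_image_folder/"))

# Assumed set of accepted image extensions; adjust to your data.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}


def list_hr_images(folder):
    """Return the sorted image file names found directly inside `folder`."""
    root = Path(folder)
    if not root.is_dir():
        return []
    return sorted(
        p.name
        for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


hr_files = list_hr_images(config.TRAIN.hr_img_path)
print(f"found {len(hr_files)} high-resolution images")
```

An empty result usually means the path in config.py is wrong or the images sit in a subfolder.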

Run

This script was run under TensorFlow 1.4 and TensorLayer 1.8.0+.

  • Installation
pip install tensorlayer==1.8.0
conda install tensorflow-gpu==1.3.0
pip install tensorflow-gpu==1.4.0
pip install easydict
  • Set your image folder in config.py.
config.TRAIN.img_path = "your_image_folder/"
  • TensorBoard logdir.

I added TensorBoard callbacks to monitor the training procedure. Please change the logdir to your own folder.

config.VALID.logdir = 'your_tensorboard_folder'
  • Start training.
python main.py
  • Start evaluation. (pretrained model for DIV2K)

An important note: the pretrained weights are provided by the original author @zsdonghao. The final conv layer of his SRGAN_g (model.py line 53) uses a 1×1 kernel, but I changed this kernel to 9×9, so loading the pretrained weights may fail with a weight shape mismatch error. Two suggestions:

1) Train the whole network from scratch. You will get the 9×9 version of the weights, which you can use for further training or for evaluating images.
2) Change the final conv kernel of SRGAN_g (model.py line 53) to (1, 1) instead of (9, 9), and change the conv kernel at model.py line 35 from (9, 9) to (3, 3), so that you can use the pretrained weights.

python main.py --mode=evaluate 
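To see why the pretrained weights do not fit the modified generator, compare the shapes of the final conv kernel. The sketch below is illustrative only: the helper names are hypothetical, and the channel counts (64 input feature maps, 3 output channels) are assumptions rather than values read from model.py.

```python
def conv_kernel_shape(kernel_size, in_channels, out_channels):
    """TensorFlow-style 2-D conv kernel shape: (height, width, in, out)."""
    height, width = kernel_size
    return (height, width, in_channels, out_channels)


def can_load(saved_shape, model_shape):
    """Saved weights can only be assigned to a variable of identical shape."""
    return saved_shape == model_shape


# Assumption: 64 feature maps go in, 3 RGB channels come out.
pretrained = conv_kernel_shape((1, 1), 64, 3)
this_repo = conv_kernel_shape((9, 9), 64, 3)
print(pretrained, this_repo, can_load(pretrained, this_repo))
```

Under these assumptions the two shapes differ in their first two dimensions, so the weights can only be restored after the kernel size in the model is changed back, or after training from scratch.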

What's new?

Compared with the original version, I made the following changes:

  1. Added WGAN, as described in the Wasserstein GAN section.
  2. Added TensorBoard to monitor the training procedure.
  3. Modified the last conv layer of 'SRGAN_g' in model.py (line 100), changing the kernel size from (1, 1) to (9, 9), as the paper proposes.

Reference

Author

License

  • For academic and non-commercial use only.
  • For commercial use, please contact [email protected].
