
peteryux / esrgan-tf2


ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks, published in ECCV 2018) implemented in TensorFlow 2.0+. This is an unofficial implementation, with a Colab notebook.

License: MIT License


esrgan-tf2's Introduction




ESRGAN introduces three improvements: the Residual-in-Residual Dense Block (RRDB) without batch normalization as the basic network building unit, the idea from relativistic GAN of letting the discriminator predict relative realness, and a perceptual loss computed on features before activation. Benefiting from these improvements, ESRGAN achieves consistently better visual quality than SRGAN, with more realistic and natural textures, and won first place in the PIRM2018-SR Challenge.
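The relativistic average discriminator idea can be sketched in a few lines of plain Python. This is a minimal illustration following the RaGAN formulation described in the ESRGAN paper, not this repository's TensorFlow code; `ragan_losses` and `log_sigmoid` are hypothetical helper names, and the inputs are assumed to be raw discriminator logits C(x) for a batch of HR (real) and SR (fake) patches.

```python
import math
from statistics import fmean


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def ragan_losses(real_logits, fake_logits):
    """Return (loss_D, loss_G) for a relativistic average GAN (sketch).

    real_logits: raw discriminator outputs C(x_r) for HR images.
    fake_logits: raw discriminator outputs C(x_f) for SR images.
    """
    mean_real = fmean(real_logits)
    mean_fake = fmean(fake_logits)
    # Relative realness: each sample compared with the average of the other group.
    rel_real = [c - mean_fake for c in real_logits]  # C(x_r) - E[C(x_f)]
    rel_fake = [c - mean_real for c in fake_logits]  # C(x_f) - E[C(x_r)]
    # Note: log(1 - sigmoid(z)) == log_sigmoid(-z).
    loss_d = (-fmean(log_sigmoid(z) for z in rel_real)
              - fmean(log_sigmoid(-z) for z in rel_fake))
    # The generator loss is symmetric: real should look less real than fake.
    loss_g = (-fmean(log_sigmoid(-z) for z in rel_real)
              - fmean(log_sigmoid(z) for z in rel_fake))
    return loss_d, loss_g
```

Because each score is taken relative to the other group's average, the discriminator estimates how much more realistic an HR patch is than the average SR patch, rather than an absolute real/fake probability.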

Original Paper: Arxiv | ECCV2018

Official Implementation: PyTorch

:: Results from this repository. ::


Installation

๐Ÿ•

Clone this repository, then either create a new Python virtual environment with Anaconda or install the requirements with pip into your existing Python environment, as follows.

Clone this repo

git clone https://github.com/peteryuX/esrgan-tf2.git
cd esrgan-tf2

Conda

conda env create -f environment.yml
conda activate esrgan-tf2

Pip

pip install -r requirements.txt

Data Preparing

๐Ÿบ

All datasets used in this repository follow the official implementation as closely as possible. This code focuses on the x4 version.

Training Dataset

Step 1: Download the DIV2K GT images and the corresponding LR images from the links below.

Dataset Name | Link
Ground-Truth | DIV2K_train_HR
LRx4 (MATLAB bicubic) | DIV2K_train_LR_bicubic_X4

Note: If you want to downsample your training data into LR images yourself, you can use imresize_np(), a NumPy implementation of MATLAB's imresize, or MATLAB's imresize itself.
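For reference, MATLAB-style bicubic downsampling is built on Keys' cubic kernel with a = -0.5, stretched by 1/scale when antialiasing. The sketch below illustrates that idea for one output sample along one axis. It is an illustration under stated assumptions, not the repository's imresize_np(): boundary handling (MATLAB pads symmetrically) is omitted, and `cubic` and `antialias_weights` are hypothetical names.

```python
import math


def cubic(x: float) -> float:
    """Keys' cubic convolution kernel with a = -0.5."""
    ax = abs(x)
    if ax <= 1:
        return 1.5 * ax**3 - 2.5 * ax**2 + 1
    if ax <= 2:
        return -0.5 * ax**3 + 2.5 * ax**2 - 4 * ax + 2
    return 0.0


def antialias_weights(out_index: int, scale: float):
    """Input positions and normalized weights for one output sample
    (0-based index) when downscaling by `scale` < 1.

    Boundary handling is omitted: positions may fall outside the image.
    """
    # Centre of the output sample mapped back to input coordinates (0-based).
    center = (out_index + 0.5) / scale - 0.5
    width = 4.0 / scale  # the 4-tap kernel support is stretched by 1/scale
    left = math.floor(center - width / 2)
    positions = [left + i for i in range(math.ceil(width) + 2)]
    # Stretched kernel: scale * cubic(scale * distance).
    raw = [scale * cubic(scale * (center - p)) for p in positions]
    total = sum(raw)
    return positions, [w / total for w in raw]
```

For the x4 setting used here (scale = 0.25), each LR pixel is a weighted average of roughly 16 HR pixels along each axis, which is what suppresses aliasing compared with plain 4-tap bicubic sampling.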

Step 2: Extract them into ./data/DIV2K/. The directory structure should look like this:

./data/DIV2K/
    -> DIV2K_train_HR/
        -> 0001.png
        -> 0002.png
        -> ...
    -> DIV2K_train_LR_bicubic/
        -> X4/
            -> 0001x4.png
            -> 0002x4.png
            -> ...

Step 3: Rename the LR images and crop both the HR and LR images into sub-images with the scripts below. Modify these scripts if you need other settings.

# rename the image files in the LR folder `DIV2K_train_LR_bicubic/*`.
python data/rename.py

# extract sub-images from the HR and LR folders.
python data/extract_subimages.py
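The cropping step can be pictured as a sliding window along each axis. This sketch assumes a BasicSR-style scheme in which one extra crop is added flush with the border when the leftover strip is wider than a threshold; the crop size, step and threshold below are placeholders, not necessarily the values used in data/extract_subimages.py, and `crop_starts` and `sub_image_boxes` are hypothetical names.

```python
def crop_starts(length: int, crop: int, step: int, thresh: int):
    """Start offsets of sliding-window crops along one axis (sketch)."""
    if length < crop:
        return []  # image too small for even one crop
    starts = list(range(0, length - crop + 1, step))
    # If the uncovered strip at the end is wide enough, add one crop
    # aligned with the image border.
    if length - (starts[-1] + crop) > thresh:
        starts.append(length - crop)
    return starts


def sub_image_boxes(h: int, w: int, crop: int, step: int, thresh: int):
    """(top, left, bottom, right) boxes of all sub-images of an h x w image."""
    return [(y, x, y + crop, x + crop)
            for y in crop_starts(h, crop, step, thresh)
            for x in crop_starts(w, crop, step, thresh)]
```

For the LR images, dividing crop size, step and threshold by the scale (4) yields positions that line up with the HR crops. For example, crop_starts(255, 120, 60, 12) gives [0, 60, 120, 135], exactly a quarter of crop_starts(1020, 480, 240, 48).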

Step 4: Convert the sub-images into a tfrecord file with the script below.

# Binary image (recommended): slower to convert, but faster to load during training.
python data/convert_train_tfrecord.py --output_path="./data/DIV2K800_sub_bin.tfrecord" --is_binary=True
# or
# Online image loading: faster to convert, but slower to load during training.
python data/convert_train_tfrecord.py --output_path="./data/DIV2K800_sub.tfrecord" --is_binary=False

Note:

  • You can run python ./dataset_checker.py to check whether the dataloader works.

Testing Dataset

Step 1: Download the common image SR datasets from the links below. Only Set5 and Set14 are needed for the default settings in ./configs/*.yaml.

Dataset Name | Short Description | Link
Set5 | Set5 test dataset | Google Drive
Set14 | Set14 test dataset | Google Drive
BSDS100 | A subset (test) of BSD500 for testing | Google Drive
Urban100 | 100 building images for testing (regular structures) | Google Drive
Manga109 | 109 images of Japanese manga for testing | Google Drive
Historical | 10 gray LR images without ground truth | Google Drive

Step 2: Extract them into ./data/. The directory structure should look like this:

./data/
    -> Set5/
        -> baby.png
        -> bird.png
        -> ...
    -> Set14/
        -> ...

Training and Testing

๐Ÿญ

Config File

You can modify the dataset paths and other model settings for training and testing in ./configs/*.yaml, which look like the example below.

# general setting
batch_size: 16
input_size: 32
gt_size: 128
ch_size: 3
scale: 4
sub_name: 'esrgan'
pretrain_name: 'psnr_pretrain'

# generator setting
network_G:
    nf: 64
    nb: 23
# discriminator setting
network_D:
    nf: 64

# dataset setting
train_dataset:
    path: './data/DIV2K800_sub_bin.tfrecord'
    num_samples: 32208
    using_bin: True
    using_flip: True
    using_rot: True
test_dataset:
    set5_path: './data/Set5'
    set14_path: './data/Set14'

# training setting
niter: 400000

lr_G: !!float 1e-4
lr_D: !!float 1e-4
lr_steps: [50000, 100000, 200000, 300000]
lr_rate: 0.5

adam_beta1_G: 0.9
adam_beta2_G: 0.99
adam_beta1_D: 0.9
adam_beta2_D: 0.99

w_pixel: !!float 1e-2
pixel_criterion: l1

w_feature: 1.0
feature_criterion: l1

w_gan: !!float 5e-3
gan_type: ragan  # gan | ragan

save_steps: 5000
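The lr_steps/lr_rate pair describes a step decay: the learning rate is multiplied by lr_rate each time training passes a boundary in lr_steps. A minimal sketch, assuming the same boundary convention as tf.keras.optimizers.schedules.PiecewiseConstantDecay (a step equal to a boundary still uses the earlier value); the repository's modules/lr_scheduler.py is the authoritative version, and `piecewise_lr` is a hypothetical name.

```python
from bisect import bisect_left


def piecewise_lr(step: int, base_lr: float, lr_steps, lr_rate: float) -> float:
    """Learning rate at `step` for a step-decay schedule (sketch)."""
    # Number of boundaries strictly below `step`: a step equal to a
    # boundary still uses the value from before that boundary.
    n_decays = bisect_left(lr_steps, step)
    return base_lr * lr_rate ** n_decays
```

With the values above, lr_G stays at 1e-4 up to step 50000, drops to 5e-5 right after it, and ends at 6.25e-6 after step 300000.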

Note:

  • sub_name is the name of the output directory used in the checkpoints and logs folders (make sure it is unique across models).
  • using_bin selects the type of training data and must match the tfrecord type you created in Data Preparing.
  • w_pixel/w_feature/w_gan are the weights used to combine the pixel/feature/GAN losses.
  • save_steps is the interval, in steps, between checkpoint saves.
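Concretely, the generator objective is a weighted sum of the three terms. A sketch using the default weights from the config above; the loss values passed in are made-up numbers, only meant to show how strongly w_pixel = 1e-2 and w_gan = 5e-3 scale down their terms relative to the feature loss. `generator_objective` is a hypothetical name.

```python
def generator_objective(losses: dict, cfg: dict) -> float:
    """Combine per-term generator losses with the config weights (sketch)."""
    return (cfg["w_pixel"] * losses["pixel"]        # L1 pixel loss
            + cfg["w_feature"] * losses["feature"]  # perceptual (feature) loss
            + cfg["w_gan"] * losses["gan"])         # adversarial loss


cfg = {"w_pixel": 1e-2, "w_feature": 1.0, "w_gan": 5e-3}
# Illustrative, made-up raw loss values (not measured results).
example = {"pixel": 0.05, "feature": 2.0, "gan": 1.0}
total = generator_objective(example, cfg)
```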

Training

Pretrain PSNR

Pretrain the PSNR RRDB model yourself, or download it from BenchmarkModels.

python train_psnr.py --cfg_path="./configs/psnr.yaml" --gpu=0

ESRGAN

Train the ESRGAN model starting from the pretrained PSNR model.

python train_esrgan.py --cfg_path="./configs/esrgan.yaml" --gpu=0

Note:

  • Make sure you have the pretrained PSNR model before training the ESRGAN model (the pretrained checkpoint should be located in ./checkpoints for restoring).
  • --gpu selects the ID of the GPU device to use, via the CUDA_VISIBLE_DEVICES environment variable.
  • You can visualize the learning rate scheduling by running "python ./modules/lr_scheduler.py".
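A --gpu flag like this is typically wired by setting CUDA_VISIBLE_DEVICES before TensorFlow initializes CUDA, so the process only sees the chosen device. A minimal sketch, assuming argparse-style flags; `parse_args` and `select_gpu` are hypothetical helpers, not the code in train_esrgan.py.

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg_path", default="./configs/esrgan.yaml")
    parser.add_argument("--gpu", default="0",
                        help="GPU id(s) to expose, e.g. '0' or '0,1'")
    return parser.parse_args(argv)


def select_gpu(gpu: str) -> None:
    # Only effective if done before TensorFlow creates its CUDA context,
    # so call this before importing tensorflow (or at least before any GPU op).
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
```

Inside the process, the selected device then appears as GPU 0, whatever its physical ID.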

Testing

You can download my trained models from Models for testing instead of training them yourself. Then evaluate the models with the corresponding cfg file on the testing datasets. The visualization results are saved to ./results/.

# Test ESRGAN model
python test.py --cfg_path="./configs/esrgan.yaml"
# or
# PSNR pretrain model
python test.py --cfg_path="./configs/psnr.yaml"

SR Input Image

You can upsample your own image with the SR model. For example, to upsample ./data/baboon.png:

python test.py --cfg_path="./configs/esrgan.yaml" --img_path="./data/baboon.png"
# or
# PSNR pretrain model
python test.py --cfg_path="./configs/psnr.yaml" --img_path="./data/baboon.png"

Network Interpolation

Produce comparison results between network interpolation and image interpolation, as in the original paper.

python net_interp.py --cfg_path1="./configs/psnr.yaml" --cfg_path2="./configs/esrgan.yaml" --img_path="./data/PIPRM_3_crop.png" --save_image=True --save_ckpt=True

Note:

  • --save_image saves the comparison results into ./results_interp.
  • --save_ckpt saves all the interpolated ckpt files into ./results_interp.
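Network interpolation blends the parameters of the two models, theta = (1 - alpha) * theta_PSNR + alpha * theta_ESRGAN, following the formula in the ESRGAN paper, while image interpolation blends their output pixels with the same alpha. A minimal sketch with parameters stored as flat lists keyed by layer name (an assumption; the real models hold TensorFlow variables, and net_interp.py may assign alpha the other way round). The function names are hypothetical.

```python
def interpolate_weights(psnr_weights: dict, gan_weights: dict, alpha: float) -> dict:
    """Per-parameter blend: (1 - alpha) * theta_PSNR + alpha * theta_GAN."""
    if psnr_weights.keys() != gan_weights.keys():
        raise ValueError("both models must have the same parameter names")
    return {
        name: [(1 - alpha) * p + alpha * g
               for p, g in zip(psnr_weights[name], gan_weights[name])]
        for name in psnr_weights
    }


def interpolate_images(psnr_img, gan_img, alpha: float):
    """Pixel-wise blend of the two SR outputs (the image interpolation baseline)."""
    return [(1 - alpha) * p + alpha * g for p, g in zip(psnr_img, gan_img)]
```

With Keras models of identical architecture, the same per-array blend could be applied to the lists returned by get_weights() and loaded back with set_weights().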

Benchmark and Visualization

☕

Evaluation results (PSNR/SSIM) and visualizations.

Set5

Image Name | Bicubic | PSNR (pretrain) | ESRGAN | Ground Truth
baby | 31.96 / 0.85 | 33.86 / 0.89 | 31.36 / 0.83 | -
bird | 30.27 / 0.87 | 35.00 / 0.94 | 32.22 / 0.90 | -
butterfly | 22.25 / 0.72 | 28.56 / 0.92 | 26.66 / 0.88 | -
head | 32.01 / 0.76 | 33.18 / 0.80 | 30.19 / 0.70 | -
woman | 26.44 / 0.83 | 30.42 / 0.92 | 28.50 / 0.88 | -

Set14 (Partial)

Image Name | Bicubic | PSNR (pretrain) | ESRGAN | Ground Truth
baboon | 22.06 / 0.45 | 22.77 / 0.54 | 20.73 / 0.44 | -
comic | 21.69 / 0.59 | 23.46 / 0.74 | 21.08 / 0.64 | -
lenna | 29.67 / 0.80 | 32.06 / 0.85 | 28.96 / 0.80 | -
monarch | 27.60 / 0.88 | 33.27 / 0.94 | 31.49 / 0.92 | -
zebra | 24.15 / 0.68 | 27.29 / 0.78 | 24.86 / 0.67 | -

Note:

  • The baseline bicubic resizing method can be found in imresize_np().
  • All PSNR and SSIM results are calculated on the Y channel of YCbCr.
  • All models were trained on DIV2K.
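The Y-channel evaluation can be sketched as follows, assuming 8-bit RGB inputs and the BT.601 conversion used by MATLAB's rgb2ycbcr (Y in [16, 235]). Images are nested lists [row][col] = (r, g, b) to keep the example dependency-free. The border argument is an assumption: many SR evaluation scripts crop `scale` pixels at each edge, but this README does not say whether that is done here. SSIM is not covered, and the function names are hypothetical.

```python
import math


def rgb_to_y(r: float, g: float, b: float) -> float:
    """BT.601 luma for 8-bit RGB, as in MATLAB's rgb2ycbcr (range 16..235)."""
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr_y(img_a, img_b, border: int = 0) -> float:
    """PSNR on the Y channel between two same-sized RGB images (sketch)."""
    h, w = len(img_a), len(img_a[0])
    sq_err = []
    for i in range(border, h - border):
        for j in range(border, w - border):
            diff = rgb_to_y(*img_a[i][j]) - rgb_to_y(*img_b[i][j])
            sq_err.append(diff * diff)
    if not sq_err:
        raise ValueError("border crop leaves no pixels")
    mse = sum(sq_err) / len(sq_err)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(255.0 ** 2 / mse)
```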

Network Interpolation (on ./data/PIPRM_3_crop.png)

weight interpolation

image interpolation

(ESRGAN <-> PSNR, alpha=[1., 0.8, 0.6, 0.4, 0.2, 0.])


Models

๐Ÿฉ

Model Name | Download Link
PSNR | GoogleDrive
ESRGAN | GoogleDrive
PSNR (inference) | GoogleDrive
ESRGAN (inference) | GoogleDrive

Note:

  • After downloading these models, extract them into ./checkpoints for restoring.
  • The inference versions were saved without any training operators, so they are smaller than the original versions. However, if you want to fine-tune, the original versions are more suitable.
  • All training settings of the models can be found in the corresponding ./configs/*.yaml files.
  • Because of the usage restrictions of the training dataset, all the pretrained models may only be used for non-commercial applications.

References

🍔

Thanks to these source codes for providing me with the knowledge to complete this repository.


esrgan-tf2's Issues

When updating the gradients, is there no need to freeze one of the networks?

When updating the gradients, the code does not freeze G or D:
import tensorflow as tf

# generator, discriminator, optimizer_G/optimizer_D, cfg and the loss
# functions are defined elsewhere in the training script.
@tf.function
def train_step(lr, hr):
    with tf.GradientTape(persistent=True) as tape:
        sr = generator(lr, training=True)
        hr_output = discriminator(hr, training=True)
        sr_output = discriminator(sr, training=True)

        losses_G = {}
        losses_D = {}
        losses_G['reg'] = tf.reduce_sum(generator.losses)
        losses_D['reg'] = tf.reduce_sum(discriminator.losses)
        losses_G['pixel'] = cfg['w_pixel'] * pixel_loss_fn(hr, sr)
        losses_G['feature'] = cfg['w_feature'] * fea_loss_fn(hr, sr)
        losses_G['gan'] = cfg['w_gan'] * gen_loss_fn(hr_output, sr_output)
        losses_D['gan'] = dis_loss_fn(hr_output, sr_output)
        total_loss_G = tf.add_n([l for l in losses_G.values()])
        total_loss_D = tf.add_n([l for l in losses_D.values()])
    # Each optimizer only receives gradients w.r.t. its own network's variables.
    grads_G = tape.gradient(
        total_loss_G, generator.trainable_variables)
    grads_D = tape.gradient(
        total_loss_D, discriminator.trainable_variables)
    optimizer_G.apply_gradients(
        zip(grads_G, generator.trainable_variables))
    optimizer_D.apply_gradients(
        zip(grads_D, discriminator.trainable_variables))

    return total_loss_G, total_loss_D, losses_G, losses_D

loss_maybe_wrong

Hi, I'm here again, because I may have found a bug in this code:

def generator_loss(hr, sr):
    return cross_entropy(tf.ones_like(sr), sigma(sr))

My question: should the first argument perhaps be tf.ones_like(hr), i.e. cross_entropy(tf.ones_like(hr), sigma(sr))?
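For context, tf.ones_like only takes the shape (and dtype) from its argument, so tf.ones_like(sr) and tf.ones_like(hr) build the same all-ones target whenever the two discriminator outputs have the same shape. A plain-Python sketch of this standard non-saturating generator loss, assuming sigma is the sigmoid applied to discriminator logits; the helper names are hypothetical and this is not the repository's code.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_cross_entropy(targets, probs, eps=1e-12):
    return -sum(t * math.log(max(p, eps)) + (1 - t) * math.log(max(1 - p, eps))
                for t, p in zip(targets, probs)) / len(probs)


def generator_loss(hr_logits, sr_logits):
    # The generator wants every SR output to be classified as real (target 1).
    # The targets only need the right length, like tf.ones_like(sr);
    # hr_logits is unused in this vanilla GAN loss.
    targets = [1.0] * len(sr_logits)
    return binary_cross_entropy(targets, [sigmoid(z) for z in sr_logits])
```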

output SR images saved in comparison mode.

I can test an image, but only in comparison mode, where the Bicubic, ESRGAN and HR outputs are stitched together.
I cannot get a single upscaled image for an LR input.

Speed

I have an SRResNet model, and it is about 6x faster than ESRGAN.
How can I make the network faster, accepting some loss in quality?

Transfer learning

Hello @peteryuX,
It's nice work, and I appreciate the port from PyTorch to TensorFlow. I have a question: can we perform transfer learning with this ESRGAN model on a custom dataset, given the training issues?
Thanks in advance.

Quality

After full training, my results are not good: an SRGAN trained by a colleague performs better than the ESRGAN I trained. Why do I see artifacts in the images? What do I have to change?

Loss is always nan

Hello PeteryuX,

Thanks a lot for sharing your implementation of ESRGAN.

I have been testing several GAN-based super-resolution networks recently. I have a large set of HR/LR training images and would like to train the ESRGAN (PSNR + ESRGAN) network using your training code.

I followed your data preparation instructions and converted my 1,825,587 LR/HR sample pairs to *bin.tfrecord. The dataset_checker reported no problems and the LR/HR images displayed correctly. I modified a few lines of your code (hardcoded paths, etc.) and started PSNR training on an RTX 3090 GPU. However, the printed "loss" is "nan" in every iteration, and even after PSNR training "successfully" finished, loss_D and loss_G in ESRGAN training are also "nan".

in psnr training:
...
Training [>> ] 20004/600000, loss=nan, lr=2.0e-04 2.0 step/sec
...

in esrgan training:
...
Training [>>> ] 40000/285240, loss_G=nan, loss_D=nan, lr_G=1.0e-04, lr_D=1.0e-04 1.4 step/sec
[*] save ckpt file at ./checkpoints/esrgan/ckpt-32
Training [>>>> ] 47877/285240, loss_G=nan, loss_D=nan, lr_G=1.0e-04, lr_D=1.0e-04 1.4 step/sec
...

Do you have any suggestions on this issue?

Here are my psnr and esrgan config files:

psnr.yaml:
batch_size: 64
input_size: 32
gt_size: 128
ch_size: 3
scale: 4
sub_name: 'psnr_pretrain'
pretrain_name: null

network_G:
    nf: 64
    nb: 23

train_dataset:
    path: '/data/EOSC/EOSC_sub_bin.tfrecord'
    num_samples: 1825587
    using_bin: True
    using_flip: True
    using_rot: True
test_dataset:
    EOSC_path: '/data2/EOSC_test'

niter: 600000
lr: !!float 2e-4
lr_steps: [200000, 300000, 400000, 500000]
lr_rate: 0.5

adam_beta1_G: 0.9
adam_beta2_G: 0.99

w_pixel: 1.0
pixel_criterion: l1
save_steps: 20000

esrgan.yaml:
batch_size: 64
input_size: 32
gt_size: 128
ch_size: 3
scale: 4
sub_name: 'esrgan'
pretrain_name: 'psnr_pretrain'

network_G:
    nf: 64
    nb: 23
network_D:
    nf: 64

train_dataset:
    path: '/data/EOSC/EOSC_sub_bin.tfrecord'
    num_samples: 1825587
    using_bin: True
    using_flip: False
    using_rot: False
test_dataset:
    EOSC_path: '/data2/EOSC_test'

niter: 285240
lr_G: !!float 1e-4
lr_D: !!float 1e-4
lr_steps: [60000, 120000, 180000, 240000]
lr_rate: 0.5

adam_beta1_G: 0.9
adam_beta2_G: 0.99
adam_beta1_D: 0.9
adam_beta2_D: 0.99

w_pixel: !!float 1e-2
pixel_criterion: l1

w_feature: 1.0
feature_criterion: l1

w_gan: !!float 5e-3
gan_type: ragan # gan | ragan

save_steps: 20000

Any help would be much appreciated! Thank you!

NOT an issue: It works amazingly well. However, with Spectral Normalization it converges VERY fast and nicely!

Your code works amazingly well. However, with Spectral Normalization it converges VERY fast and nicely!

That is all I did. I changed two things: (1) added Spectral Normalization to the Conv2D layers of the discriminator network (except the first Conv2D layer), and (2) replaced VGG19 with VGGFace, because my problem is upscaling unclear faces for better face recognition / triplet loss.

Thanks mate.
Steve
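Spectral normalization divides a layer's weight matrix by its largest singular value, which is usually estimated by power iteration. A dependency-free sketch of that estimate, with hypothetical helper names. For a Conv2D kernel one would first reshape it into a 2-D matrix, and practical implementations such as the SpectralNormalization wrapper in TensorFlow Addons typically run one iteration per training step and keep u between steps, instead of iterating from scratch as done here.

```python
import math
import random


def _matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def _normalize(x, eps=1e-12):
    n = math.sqrt(sum(a * a for a in x)) + eps
    return [a / n for a in x]


def spectral_norm(w, n_iter=50, seed=0):
    """Estimate the largest singular value of matrix w (list of rows)."""
    wt = [list(col) for col in zip(*w)]  # W^T
    rng = random.Random(seed)
    u = _normalize([rng.gauss(0.0, 1.0) for _ in range(len(w))])
    for _ in range(n_iter):  # n_iter must be >= 1
        v = _normalize(_matvec(wt, u))  # v <- W^T u / ||W^T u||
        u = _normalize(_matvec(w, v))   # u <- W v / ||W v||
    return sum(a * b for a, b in zip(u, _matvec(w, v)))  # sigma ~= u^T W v


def spectrally_normalize(w, n_iter=50):
    """Return W / sigma(W), whose largest singular value is ~1."""
    sigma = spectral_norm(w, n_iter)
    return [[x / sigma for x in row] for row in w]
```

Constraining each discriminator layer this way bounds how sharply the discriminator can change with its input, which is one common explanation for the more stable, faster convergence reported above.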

