
karoly-hars / gan_image_colorizing

Image colorization with generative adversarial networks on the CIFAR10 dataset.

License: MIT License

Python 100.00%
gan generative-adversarial-network deep-learning pytorch image-colorization cifar10 unet spectral-normalization batch-normalization computer-vision

gan_image_colorizing's Introduction

Image colorization with GANs

The aim of this project is to explore the topic of image colorization with the help of Generative Adversarial Networks.

Approach

The project builds heavily on the findings of the paper Image Colorization with Generative Adversarial Networks by Nazeri et al. As in the paper, I studied image colorization on the CIFAR10 dataset using adversarial learning. Additionally, I experimented with spectral normalization to stabilize the discriminator and, with it, the whole training procedure.
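To make the idea of spectral normalization more concrete, here is a minimal NumPy sketch. It is not the implementation used in this repository (that one comes from the repo credited in the acknowledgements); the function name `spectral_normalize` and its parameters are hypothetical. The sketch estimates the largest singular value of a weight tensor by power iteration and divides the weights by it, so the rescaled layer has a spectral norm of roughly 1.

```python
import numpy as np

def spectral_normalize(weight, n_iters=200, eps=1e-12, seed=0):
    """Divide a weight tensor by an estimate of its largest singular value.

    Hypothetical helper for illustration only, not taken from this repo.
    """
    # Conv-style weights (out, in, kh, kw) are flattened to a matrix (out, -1).
    w = weight.reshape(weight.shape[0], -1)
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(w.shape[0])
    u /= np.linalg.norm(u) + eps
    # Power iteration: alternate between right and left singular vector estimates.
    for _ in range(n_iters):
        v = w.T @ u
        v /= np.linalg.norm(v) + eps
        u = w @ v
        u /= np.linalg.norm(u) + eps
    sigma = float(u @ w @ v)  # estimate of the largest singular value
    return weight / sigma, sigma

rng = np.random.default_rng(1)
conv_weight = rng.standard_normal((8, 3, 4, 4))  # made-up discriminator kernel
normalized, sigma = spectral_normalize(conv_weight)

# After rescaling, the largest singular value should be close to 1.
top_singular_value = np.linalg.svd(normalized.reshape(8, -1), compute_uv=False)[0]
print(sigma, top_singular_value)
```

The sketch runs many iterations so that the estimate is accurate in isolation. Training implementations typically run a single iteration per step and keep `u` between steps, which is cheaper and works because the weights change slowly.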

Prerequisites

  • numpy
  • torch (only tested with 1.0.0)
  • opencv-python

Running the code

Training using regular batch normalization:

python3 train.py --max_epoch=200 --smoothing=0.9 --l1_weight=0.99 --base_lr_gen=3e-4 \
--base_lr_disc=6e-5 --lr_decay_steps=6e4 --disc_norm=batch --apply_weight_init=1

Training using spectral normalization:

python3 train.py --max_epoch=100 --smoothing=1.0 --l1_weight=0.985 --base_lr_gen=2e-4 \
--base_lr_disc=2e-4 --lr_decay_steps=4e4 --disc_norm=spectral --apply_weight_init=0

During training, a sample of images is saved after every epoch.
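The `--smoothing` and `--l1_weight` flags suggest one-sided label smoothing for the discriminator and a weighted mix of L1 and adversarial terms for the generator. The sketch below shows one plausible way such flags could enter the losses. It is an assumption, not a copy of train.py: the function names, and in particular the `l1_weight * L1 + (1 - l1_weight) * adversarial` combination, are hypothetical.

```python
import numpy as np

def bce(pred, target, eps=1e-7):
    """Binary cross-entropy on probabilities, averaged over all elements."""
    pred = np.clip(pred, eps, 1 - eps)
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))

def discriminator_loss(d_real, d_fake, smoothing=0.9):
    # One-sided label smoothing: real targets become `smoothing`, fake targets stay 0.
    real_loss = bce(d_real, np.full_like(d_real, smoothing))
    fake_loss = bce(d_fake, np.zeros_like(d_fake))
    return real_loss + fake_loss

def generator_loss(d_fake, fake_colors, real_colors, l1_weight=0.99):
    # Adversarial term: the generator wants the discriminator to output 1 on fakes.
    adversarial = bce(d_fake, np.ones_like(d_fake))
    # Reconstruction term: mean absolute error against the original colors.
    l1 = float(np.mean(np.abs(fake_colors - real_colors)))
    # ASSUMPTION: a convex combination controlled by l1_weight.
    return l1_weight * l1 + (1 - l1_weight) * adversarial

rng = np.random.default_rng(0)
d_real = rng.uniform(0.6, 0.95, size=16)            # made-up discriminator outputs
d_fake = rng.uniform(0.05, 0.4, size=16)
real_colors = rng.uniform(-1, 1, size=(16, 2, 32, 32))
fake_colors = real_colors + 0.1 * rng.standard_normal(real_colors.shape)

print(discriminator_loss(d_real, d_fake, smoothing=0.9))
print(generator_loss(d_fake, fake_colors, real_colors, l1_weight=0.99))
```

With `smoothing=1.0`, as in the spectral normalization command, the discriminator targets reduce to the usual hard labels.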

To go through the CIFAR10 test set and display/save some re-colorized images:

python3 test.py

Results

Utilizing batch normalization in both networks resulted in vivid, colorful images. This is often an advantage when the sample images contain cars, trucks, or other objects with sharper colors. On the other hand, it can lead to over-colorization on blander images.

As expected, using spectral normalization in the discriminator network stabilized the training process and reduced the number of required training steps. The results show that this approach produces better and more believable colors. However, when the input grayscale image contains an object with a number of possible colors (for example cars), the generator network often produces grayish outputs.

Qualitative evaluation

(left to right: original RGB, grayscale input, the output of the generator from the batch norm training, the output of the generator from the spectral norm training)

Some test examples where spectral normalization is superior:

Some test examples where batch normalization produces better results:

It is important to note that, because of the adversarial loss function, approximating the original colors of the images is not the only goal of the generators. They also have to fill the images with realistic, believable colors in order to fool the discriminators (although these two tasks can be equivalent). In this regard, the models perform quite well, often creating colorful and lifelike samples.

Test samples where both generators struggle:

Acknowledgements

  1. Most of the work is based on the aforementioned paper and the corresponding GitHub repo.
  2. The excellent PyTorch-GAN repository served as a great guideline for the training implementation.
  3. Spectral normalization comes from the article of Miyato et al.
  4. The spectral norm PyTorch implementation is from Christian Cosgrove's repo.

gan_image_colorizing's Issues

A question on the dropout noise

I'm sorry to disturb you. I found that your code never calls "model.train()" or "model.eval()", so dropout and batch normalization stay in training mode during testing as well. Does this have an impact on the results?
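For context on the question: PyTorch modules start in training mode, and calling `model.eval()` switches dropout layers to the identity and makes batch normalization use its running statistics instead of per-batch statistics. The NumPy sketch below imitates only the dropout half of that behavior with a hypothetical `dropout` helper. It does not say whether leaving dropout active at test time is intentional in this repository.

```python
import numpy as np

def dropout(x, p=0.5, training=True, rng=None):
    """Inverted dropout; a hypothetical stand-in for a framework dropout layer."""
    if not training or p == 0.0:
        # Inference mode: the input passes through unchanged.
        return x
    if rng is None:
        rng = np.random.default_rng()
    # Training mode: zero each element with probability p, rescale the survivors.
    mask = rng.random(x.shape) >= p
    return x * mask / (1.0 - p)

features = np.ones((4, 4))
rng = np.random.default_rng(0)

train_a = dropout(features, training=True, rng=rng)   # random mask
train_b = dropout(features, training=True, rng=rng)   # another random mask
eval_out = dropout(features, training=False)          # deterministic

print(train_a)
print(train_b)
print(eval_out)
```

If the mode is never switched, two test runs on the same grayscale input can therefore give slightly different colorizations, because each forward pass draws a new dropout mask.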

img_lab normalization

Could you please explain the meaning of this code when doing normalization for img_lab?

img_bgr = img_bgr.astype(np.float32)/255.0
# transform to lab
img_lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
# normalize
img_lab[:, :, 0] = img_lab[:, :, 0]/50 - 1
img_lab[:, :, 1] = img_lab[:, :, 1]/127
img_lab[:, :, 2] = img_lab[:, :, 2]/127
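A likely reading of the snippet: for float32 input scaled to [0, 1], OpenCV's BGR-to-Lab conversion returns L in [0, 100] and a, b in [-127, 127]. Dividing L by 50 and subtracting 1, and dividing a and b by 127, therefore maps all three channels to roughly [-1, 1], a common input range for networks. The sketch below checks this on the extreme values with NumPy only. `normalize_lab` mirrors the lines quoted above, while `denormalize_lab` is a hypothetical inverse added for illustration.

```python
import numpy as np

def normalize_lab(img_lab):
    """Map float32 Lab values (L: 0..100, a/b: -127..127) to roughly [-1, 1]."""
    out = img_lab.astype(np.float32)  # astype returns a copy, the input is untouched
    out[:, :, 0] = out[:, :, 0] / 50 - 1
    out[:, :, 1] = out[:, :, 1] / 127
    out[:, :, 2] = out[:, :, 2] / 127
    return out

def denormalize_lab(img_norm):
    """Hypothetical inverse of normalize_lab, e.g. before converting back to BGR."""
    out = img_norm.astype(np.float32)
    out[:, :, 0] = (out[:, :, 0] + 1) * 50
    out[:, :, 1] = out[:, :, 1] * 127
    out[:, :, 2] = out[:, :, 2] * 127
    return out

# A 1x2 "image" holding the minimum and maximum of each Lab channel.
extremes = np.array([[[0.0, -127.0, -127.0],
                      [100.0, 127.0, 127.0]]], dtype=np.float32)

normalized = normalize_lab(extremes)
restored = denormalize_lab(normalized)

print(normalized.min(), normalized.max())
print(np.allclose(restored, extremes))
```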
