
Generative Adversarial Network (GAN) with Intrinsic Reward

DCGAN with a simple intrinsic reward implementation for novel image generation


Abstract

This project implements a Generative Adversarial Network (GAN) with Intrinsic Reward, which is a modification of the traditional GAN model. The intrinsic reward provides a reward signal to the generator during training, in addition to the adversarial loss, and is designed to encourage the generator to produce diverse and novel images.

This project was an experiment in adding a simple intrinsic reward model to a generative image model to see whether it can guide the generator toward novel image generation.

GAN

The GAN model is trained using PyTorch, and it consists of a generator and a discriminator. The generator generates fake images from random noise, while the discriminator tries to distinguish between real and fake images. During training, the generator tries to improve its ability to generate realistic images, while the discriminator tries to improve its ability to distinguish between real and fake images.
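
The exact architecture is described only as a DCGAN, so the following is a rough PyTorch sketch of such a generator/discriminator pair for 64 x 64 RGB images; the layer widths (`feat = 64`) and the exact layer stack are assumptions, while the 128-dimensional latent input comes from the Demonstration section below.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator: 128-dim latent vector -> 64 x 64 RGB image."""
    def __init__(self, latent_dim=128, feat=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, feat * 8, 4, 1, 0, bias=False),  # 1x1 -> 4x4
            nn.BatchNorm2d(feat * 8), nn.ReLU(True),
            nn.ConvTranspose2d(feat * 8, feat * 4, 4, 2, 1, bias=False),    # 4x4 -> 8x8
            nn.BatchNorm2d(feat * 4), nn.ReLU(True),
            nn.ConvTranspose2d(feat * 4, feat * 2, 4, 2, 1, bias=False),    # 8x8 -> 16x16
            nn.BatchNorm2d(feat * 2), nn.ReLU(True),
            nn.ConvTranspose2d(feat * 2, feat, 4, 2, 1, bias=False),        # 16x16 -> 32x32
            nn.BatchNorm2d(feat), nn.ReLU(True),
            nn.ConvTranspose2d(feat, 3, 4, 2, 1, bias=False),               # 32x32 -> 64x64
            nn.Tanh(),  # outputs in [-1, 1], matching the normalization used in preprocessing
        )

    def forward(self, z):
        # z: (batch, latent_dim) -> reshape to (batch, latent_dim, 1, 1) for the conv stack
        return self.net(z.view(z.size(0), -1, 1, 1))


class Discriminator(nn.Module):
    """DCGAN-style discriminator: 64 x 64 RGB image -> real/fake logit."""
    def __init__(self, feat=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, feat, 4, 2, 1, bias=False),                        # 64 -> 32
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feat, feat * 2, 4, 2, 1, bias=False),                 # 32 -> 16
            nn.BatchNorm2d(feat * 2), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feat * 2, feat * 4, 4, 2, 1, bias=False),             # 16 -> 8
            nn.BatchNorm2d(feat * 4), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feat * 4, feat * 8, 4, 2, 1, bias=False),             # 8 -> 4
            nn.BatchNorm2d(feat * 8), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feat * 8, 1, 4, 1, 0, bias=False),                    # 4 -> 1 (logit)
        )

    def forward(self, x):
        return self.net(x).view(-1)
```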

Intrinsic Reward

The intrinsic reward was used to guide the generator's learning process in addition to the usual adversarial loss. After the generator produced fake images from random noise, its output was passed through the intrinsic reward model to compute the intrinsic reward loss. This loss was then added to the generator loss to obtain the total loss used to update the generator's parameters. The intrinsic reward loss is designed to encourage the generator to produce images that have desirable features or qualities beyond simply being realistic-looking. In this implementation, the intrinsic reward model's output was compared to the original random noise used to generate the fake images, and the squared L2 distance between the two was used as the intrinsic reward loss.

$$loss_{IR} = ||IR(G(z)) - z||_2^2$$
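
As a minimal sketch of how this could look in PyTorch: the `IntrinsicReward` encoder below (its layer stack is an assumption, not the repository's exact module) maps a generated image back to a 128-dimensional vector, and the squared L2 distance to the original noise `z` is added to the adversarial loss during the generator update. Here `criterion` is assumed to be `nn.BCEWithLogitsLoss()`, `ir_weight` is an assumed weighting knob, and how the intrinsic reward network itself is trained is not shown.

```python
import torch
import torch.nn as nn

class IntrinsicReward(nn.Module):
    """Hypothetical intrinsic reward model: 64 x 64 image -> 128-dim vector."""
    def __init__(self, latent_dim=128, feat=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, feat, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),             # 64 -> 32
            nn.Conv2d(feat, feat * 2, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),      # 32 -> 16
            nn.Conv2d(feat * 2, feat * 4, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),  # 16 -> 8
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(feat * 4, latent_dim),
        )

    def forward(self, images):
        return self.encoder(images)


def generator_step(G, D, IR, criterion, opt_G, batch_size, latent_dim, device, ir_weight=1.0):
    """One generator update combining the adversarial loss and the intrinsic reward loss."""
    z = torch.randn(batch_size, latent_dim, device=device)
    fake = G(z)

    # Adversarial part: the generator wants the discriminator to call its samples real (label 1).
    adv_loss = criterion(D(fake), torch.ones(batch_size, device=device))

    # Intrinsic reward part: squared L2 distance ||IR(G(z)) - z||_2^2, averaged over the batch.
    ir_loss = (IR(fake) - z).pow(2).sum(dim=1).mean()

    total_loss = adv_loss + ir_weight * ir_loss
    opt_G.zero_grad()
    total_loss.backward()
    opt_G.step()
    return total_loss.item()
```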

Dataset

An art portrait subset of the Wiki-Art: Visual Art Encyclopedia dataset was used to train the GAN model. This subset contains 4,117 art portrait images.

The dataset is available at the links below:

Wiki-Art: Visual Art Encyclopedia: https://www.kaggle.com/datasets/ipythonx/wikiart-gangogh-creating-art-gan

An art portrait subset: https://www.kaggle.com/datasets/karnikakapoor/art-portraits

[Image: sample portraits from the training dataset]

Data Preprocessing

The training dataset, consisting of art portraits, was preprocessed to reduce the image size to 64 x 64 for faster training and better memory efficiency. This was done using the PyTorch transforms module, which applies a series of image transformations to the dataset.

First, the images were resized to the desired size using transforms.Resize(img_size). Then, a center crop of the same size was taken using transforms.CenterCrop(img_size) to ensure that all images are of the same size. To increase the diversity of the training dataset, transforms.RandomHorizontalFlip(p=0.5) was applied to randomly flip the images horizontally with a probability of 0.5.

To further augment the dataset, random color jitter and rotation were applied to the images using transforms.ColorJitter() and transforms.RandomRotation(degrees=20). These transforms were randomly applied to each image with a probability of 0.2 using transforms.RandomApply(random_transforms, p=0.2).

After the image transforms, the images were converted to PyTorch tensors using transforms.ToTensor(), and then normalized using transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)). This ensures that the pixel values are in the range of [-1, 1], which is suitable for training GANs.
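
Putting these steps together, the preprocessing pipeline likely looks roughly like the sketch below; the transforms and their order follow the text, while loading the images via `ImageFolder` from a `data/portraits` directory is an assumption.

```python
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

img_size = 64

# Augmentations applied together with probability 0.2.
# Note: torchvision's ColorJitter() with default arguments applies no jitter;
# brightness/contrast/saturation strengths would need to be set explicitly.
random_transforms = [transforms.ColorJitter(), transforms.RandomRotation(degrees=20)]

transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply(random_transforms, p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # pixel values -> [-1, 1]
])

# Assumed directory layout; the Kaggle portrait subset may be organized differently.
train_data = ImageFolder(root="data/portraits", transform=transform)
```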

The preprocessed dataset was then loaded into a PyTorch DataLoader with a specified batch size of 32, using DataLoader(train_data, batch_size=batch_size, shuffle=True). The shuffle=True argument ensures that the images are randomly shuffled before being loaded into each batch during training.

Finally, to check that the preprocessing was done correctly, a batch of images and their corresponding labels were extracted from the train_loader using imgs, label = next(iter(train_loader)). The images were then transposed to have a shape of (batch_size, height, width, channels) using imgs = imgs.numpy().transpose(0, 2, 3, 1).
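
A sketch of this loading and sanity-check step is shown below; the matplotlib preview at the end is an assumption added here for illustration.

```python
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Pull one batch and move channels last: (batch, channels, H, W) -> (batch, H, W, channels).
imgs, label = next(iter(train_loader))
imgs = imgs.numpy().transpose(0, 2, 3, 1)

# Un-normalize from [-1, 1] back to [0, 1] before displaying the first image.
plt.imshow(imgs[0] * 0.5 + 0.5)
plt.axis("off")
plt.show()
```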

[Image: a batch of preprocessed training images]

Demonstration

To compare the images generated by the GAN without intrinsic reward and the GAN with intrinsic reward, I set the random seed to 3407 and applied weight initialization to the generator, discriminator, and intrinsic reward networks. This ensured that both models used the same batch of real input images and the same fixed latent noise vector for generating test images.

During training, I used a 128-dimensional latent vector as input to the generator; each generated image is determined by 128 scalar noise values, and varying them produces different outputs. A higher-dimensional latent space lets the generator produce more complex outputs, but it also slows down training.
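
A sketch of this setup is below; the seed and the 128-dimensional latent size come from the text, while the DCGAN-style weight initializer and the number of fixed noise vectors are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(3407)  # same seed for both runs

def weights_init(m):
    """DCGAN-style initialization (a common choice; the repository's exact initializer may differ)."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

# Applied to all three networks so both runs start from identical weights, e.g.:
# for net in (generator, discriminator, intrinsic_reward):
#     net.apply(weights_init)

# Fixed latent noise reused every time test images are generated, so samples stay comparable.
latent_dim = 128
fixed_noise = torch.randn(64, latent_dim)  # 64 samples is an assumed grid size
```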

GAN without Intrinsic Reward after 50 epochs of training:

[Image: generated_images_no_intrinsic_reward]

GAN with Intrinsic Reward after 50 epochs of training:

[Image: generated_images_intrinsic_reward]

Challenge

It is hard to compare the novelty of two sets of images. Therefore, further research is needed to explore metrics for novelty comparison.

Some possible metrics are:

  • Wasserstein Distance
  • Entropy Score
  • Information Gain (not a metric)

Reference

Some of my work (GAN concept, architecture, training process, etc.) was inspired by the following:
