theopfr / cycle-gan-pytorch

🐴🔄🦓 Cycle-GAN implemented in PyTorch

This repository contains an implementation of the Cycle-GAN architecture as proposed in the original paper, along with instructions for training it on your own dataset.


⬇️ setup:

1. clone the repository:

git clone https://github.com/theopfr/cycle-gan-pytorch.git
cd cycle-gan-pytorch

2. install requirements:

Requirements: Python>=3.7, PyTorch, torchvision, tqdm, numpy

pip install -r requirements.txt

🏋️ train:

1. create a dataset:

  • create a folder with a descriptive name inside datasets/ to store your dataset
  • inside it, create two sub-folders: trainA and trainB
  • put all images of one image category into trainA and all images of the other category into trainB (e.g. all images of horses in trainA and all images of zebras in trainB)
You can find datasets here.
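The folder layout described above can be prepared and sanity-checked with a short script. This is a minimal sketch, not part of the repository: `prepare_dataset` and `count_images` are hypothetical helper names, and the set of accepted image extensions is an assumption (the repository's data loader may accept other formats).

```python
from pathlib import Path
from typing import Dict

# Hypothetical helpers for the layout datasets/<dataset_name>/trainA and
# datasets/<dataset_name>/trainB. The accepted extensions are an ASSUMPTION;
# the repository's data loader may handle a different set of formats.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
DOMAINS = ("trainA", "trainB")


def prepare_dataset(root: Path, dataset_name: str) -> Path:
    """Create datasets/<dataset_name>/trainA and trainB if they are missing."""
    dataset_dir = root / "datasets" / dataset_name
    for domain in DOMAINS:
        (dataset_dir / domain).mkdir(parents=True, exist_ok=True)
    return dataset_dir


def count_images(dataset_dir: Path) -> Dict[str, int]:
    """Count the image files (by extension, case-insensitive) in each domain folder."""
    return {
        domain: sum(
            1
            for path in (dataset_dir / domain).iterdir()
            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
        )
        for domain in DOMAINS
    }
```

From the repository root, `count_images(prepare_dataset(Path("."), "horse-zebra-dataset"))` creates any missing folders and reports how many images each domain currently holds; a count of 0 means that folder still needs to be filled.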

2. run the train script:

  • navigate to src/
  • run the train.py script, specifying the training arguments and hyperparameters as command-line flags (all arguments are listed in the table below; run_name and dataset_name are required)
  • example:
    python .\train.py --run_name "horse-zebra-run" --dataset_name "horse-zebra-dataset" --save_image_intervall 50 --resume False --epochs 100 --image_size 256 --batch_size 1 --num_res_blocks 9 --lr 0.0002 --lr_decay_rate 1 --lr_decay_intervall 200 --gaussian_noise_rate 0.05 --lambda_adversarial 1 --lambda_cycle 10 --lambda_identity 1 
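To make the flag types concrete, here is a hypothetical `argparse` parser that mirrors the names, types and defaults documented in the flag table further down. It is an illustrative sketch, not the repository's `train.py`; in particular, the explicit string-to-boolean conversion for `--resume` is an assumption about how a string flag with the options "True"/"False" is best handled.

```python
import argparse

# Hypothetical parser mirroring the documented flag names, types and defaults.
# This is NOT the repository's own train.py.


def str_to_bool(value: str) -> bool:
    # --resume is documented as a string with the options "True"/"False".
    # bool("False") would be True, so the string is compared explicitly.
    if value not in ("True", "False"):
        raise argparse.ArgumentTypeError('expected "True" or "False"')
    return value == "True"


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Cycle-GAN training (sketch)")
    p.add_argument("--run_name", type=str, required=True)
    p.add_argument("--dataset_name", type=str, required=True)
    p.add_argument("--resume", type=str_to_bool, default=False)
    p.add_argument("--save_image_intervall", type=int, default=50)
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--image_size", type=int, default=256)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--num_res_blocks", type=int, default=9)
    p.add_argument("--lr", type=float, default=0.0002)
    p.add_argument("--lr_decay_rate", type=float, default=1.0)
    p.add_argument("--lr_decay_intervall", type=int, default=200)
    p.add_argument("--gaussian_noise_rate", type=float, default=0.05)
    # The table lists the loss weights as int, so the sketch does the same.
    p.add_argument("--lambda_adversarial", type=int, default=1)
    p.add_argument("--lambda_cycle", type=int, default=10)
    p.add_argument("--lambda_identity", type=int, default=1)
    return p
```

`build_parser().parse_args(["--run_name", "horse-zebra-run", "--dataset_name", "horse-zebra-dataset"])` returns a namespace in which every other flag keeps its default; omitting either required flag makes `argparse` exit with an error message.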
    

🚩 train script flags/arguments:

| argument | type | default | description |
|---|---|---|---|
| run_name | str | required | Name of the training run (a folder with this name is created inside runs/ to store training metrics, model checkpoints and generated images). |
| dataset_name | str | required | Name of the folder inside datasets/ that holds the dataset to train on. |
| resume | str | False | Options: "True", "False". Specifies whether a previously interrupted training run should be continued (if set to "False", the run folder is reinitialized). |
| save_image_intervall | int | 50 | Number of iterations (not epochs!) after which generated images are saved to the run folder. |
| epochs | int | 100 | Number of epochs to train. |
| image_size | int | 256 | Size to which all images are resized (the resized images are square). |
| batch_size | int | 1 | The batch size. |
| num_res_blocks | int | 9 | Number of residual blocks in the generator. |
| lr | float | 0.0002 | The learning rate. |
| lr_decay_rate | float | 1.0 | Factor the learning rate is multiplied by at each decay step (1.0 means no decay). |
| lr_decay_intervall | int | 200 | Number of iterations (not epochs!) after which the learning rate is decayed (must be >= 1). |
| gaussian_noise_rate | float | 0.05 | Amount of Gaussian noise added to images before they are fed into the discriminator (the value is multiplied with random noise, and the result is added to the images). |
| lambda_adversarial | int | 1 | Weight of the adversarial loss (the loss is multiplied by this value). |
| lambda_cycle | int | 10 | Weight of the cycle-consistency loss (the loss is multiplied by this value). |
| lambda_identity | int | 1 | Weight of the identity loss (the loss is multiplied by this value). |
All default values are chosen to match the original paper's setup for training on the horse-zebra dataset.
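The following sketch shows one way the numeric flags could enter a training step: a step-wise learning-rate decay, Gaussian noise on the discriminator input, and the weighted sum of the three generator losses. The function names are hypothetical, and the exact schedule (decay counted from iteration 0), the standard-normal noise and the plain weighted sum are assumptions read off the flag descriptions, not code taken from the repository.

```python
import random
from typing import List, Optional

# ASSUMPTIONS: step-wise decay counted from iteration 0, standard-normal noise,
# and a plain weighted sum of the generator losses. Not the repository's code.


def learning_rate_at(iteration: int, lr: float = 0.0002,
                     lr_decay_rate: float = 1.0,
                     lr_decay_intervall: int = 200) -> float:
    """Multiply lr by lr_decay_rate once per completed lr_decay_intervall iterations."""
    if lr_decay_intervall < 1:
        raise ValueError("lr_decay_intervall has to be >= 1")
    return lr * lr_decay_rate ** (iteration // lr_decay_intervall)


def add_discriminator_noise(pixels: List[float], gaussian_noise_rate: float = 0.05,
                            rng: Optional[random.Random] = None) -> List[float]:
    """Add gaussian_noise_rate * N(0, 1) to every pixel value.

    For PyTorch tensors the same idea reads
    images + gaussian_noise_rate * torch.randn_like(images).
    """
    rng = rng or random.Random()
    return [p + gaussian_noise_rate * rng.gauss(0.0, 1.0) for p in pixels]


def generator_loss(adversarial: float, cycle: float, identity: float,
                   lambda_adversarial: float = 1, lambda_cycle: float = 10,
                   lambda_identity: float = 1) -> float:
    """Weight each generator loss term by its lambda flag and sum them."""
    return (lambda_adversarial * adversarial
            + lambda_cycle * cycle
            + lambda_identity * identity)
```

With the defaults, `learning_rate_at` returns 0.0002 at every iteration because `lr_decay_rate` is 1.0; with `lr_decay_rate=0.5` and the default interval of 200, the rate at iteration 450 is 0.0002 × 0.5² = 0.00005. The default `lambda_cycle` of 10 makes the cycle-consistency term dominate the generator loss unless the other terms are much larger.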
