lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two.

License: MIT License

Python 100.00%
artificial-intelligence deep-learning generative-adversarial-network

lightweight-gan's Introduction

512x512 flowers after 12 hours of training, 1 gpu

256x256 flowers after 12 hours of training, 1 gpu

Pizza

'Lightweight' GAN


Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary: "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".
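
To make the skip-layer excitation concrete, here is a minimal sketch of the idea in PyTorch. The channel sizes and module details are illustrative rather than the repository's exact code: a low-resolution feature map from early in the generator is squeezed into per-channel gates that modulate a much higher-resolution feature map, skipping everything in between.

import torch
from torch import nn

class SkipLayerExcitation(nn.Module):
    def __init__(self, low_res_channels, high_res_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.AdaptiveAvgPool2d((4, 4)),                       # squeeze the low-res map to 4x4
            nn.Conv2d(low_res_channels, high_res_channels, 4),  # 4x4 conv -> 1x1 spatially
            nn.LeakyReLU(0.1),
            nn.Conv2d(high_res_channels, high_res_channels, 1),
            nn.Sigmoid()                                        # per-channel gates in [0, 1]
        )

    def forward(self, high_res, low_res):
        # the excitation signal skips many layers: an early, low-res map gates a late, high-res map
        return high_res * self.net(low_res)

# usage: an 8x8 feature map (512 channels) modulates a 128x128 one (64 channels)
sle = SkipLayerExcitation(512, 64)
out = sle(torch.randn(1, 64, 128, 128), torch.randn(1, 512, 8, 8))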

Install

$ pip install lightweight-gan

Use

One command

$ lightweight_gan --data ./path/to/images --image-size 512

The model will be saved to ./models/{name} every 1000 iterations, and samples from the model will be saved to ./results/{name}. name defaults to default.

Training settings

Pretty self-explanatory for deep learning practitioners:

$ lightweight_gan \
    --data ./path/to/images \
    --name {name of run} \
    --batch-size 16 \
    --gradient-accumulate-every 4 \
    --num-train-steps 200000

Augmentation

Augmentation is essential for Lightweight GAN to work effectively in a low-data setting.

By default, the augmentation types are set to translation and cutout, with color omitted. You can include color as well with the following:

$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [translation,cutout,color]

Test augmentation

You can preview how your images will be augmented before they are passed into the network (if you use augmentation). Let's see how it works on this image:

Basic usage

To augment your image, pass --aug-test and point --data at your image:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg

This will create the file lena_augs.jpg, which will look something like this:

Options

You can use a few options to change the result:

  • --image-size 256 to change the size of the image tiles in the result. Default: 256.
  • --aug-types [color,cutout,translation] to combine several augmentations. Default: [cutout,translation].
  • --batch-size 10 to change the number of images in the result image. Default: 10.
  • --num-image-tiles 5 to change the number of tiles in the result image. Default: 5.

Try this command:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg \
    --batch-size 16 \
    --num-image-tiles 4 \
    --aug-types [color,translation]

The result will look something like this:

Types of augmentations

This library contains several types of embedded augmentations.
Some of them run by default; others can be selected as options in --aug-types:

  • Horizontal flip (runs by default inside the AugWrapper class; not user-controllable);
  • color randomly changes brightness, saturation, and contrast;
  • cutout creates random black boxes on the image;
  • offset randomly shifts the image along the x- and y-axes, wrapping the image around;
    • offset_h shifts only along the x-axis;
    • offset_v shifts only along the y-axis;
  • translation randomly moves the image on a canvas with a black background.

The full set of augmentations is --aug-types [color,cutout,offset,translation].
A general recommendation is to use as many augmentations as are suitable for your data, and then, after some time training, to disable the augmentations that are most destructive to the image.

Color

Cutout

Offset

Only x-axis:

Only y-axis:

Translation

Mixed precision

You can turn on automatic mixed precision with one flag: --amp.

You should expect it to be about 33% faster and to save up to 40% memory.
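
Under the hood, --amp corresponds to PyTorch's native mixed-precision machinery (autocast plus a gradient scaler). A generic, self-contained sketch of the pattern, with a toy model standing in for the GAN:

import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten(), nn.LazyLinear(1)).cuda()
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    images = torch.randn(4, 3, 64, 64, device='cuda')
    opt.zero_grad()
    with torch.cuda.amp.autocast():   # forward pass runs in mixed precision
        loss = model(images).mean()
    scaler.scale(loss).backward()     # scale the loss so fp16 gradients don't underflow
    scaler.step(opt)                  # unscales gradients; skips the step on inf/NaN
    scaler.update()                   # adjusts the scale factor for the next iteration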

Multiple GPUs

Likewise, a single flag enables multiple GPUs: --multi-gpus.

Visualizing training insights with Aim

Aim is an open-source experiment tracker that logs your training runs, provides a UI to compare them, and exposes an API to query them programmatically.

First, install Aim with pip:

$ pip install aim

Next, you can specify the Aim logs directory with the --aim_repo flag; otherwise, logs will be stored in the current directory:

$ lightweight_gan --data ./path/to/images --image-size 512 --use-aim --aim_repo ./path/to/logs/

Execute aim up --repo ./path/to/logs/ to run the Aim UI on your server.
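
To query the tracked runs programmatically, Aim's Python SDK can be used along these lines (the metric name 'G' is a guess at what the trainer tracks; substitute the actual name):

from aim import Repo

repo = Repo('./path/to/logs')

# iterate runs matching a query and read back the metric values
for run_metrics in repo.query_metrics("metric.name == 'G'").iter_runs():
    for metric in run_metrics:
        steps, values = metric.values.sparse_numpy()
        print(metric.run.hash, values[-1] if len(values) else None)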

View all tracked runs, each metric's last tracked values, and tracked hyperparameters in the Runs Dashboard:


Compare loss curves with the Metrics Explorer; group and aggregate by any hyperparameter to easily compare runs:


Compare and debug generated images across training steps and runs via the Images Explorer:


Generating

Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If --load-from is not specified, it defaults to the latest checkpoint.

$ lightweight_gan \
  --name {name of run} \
  --load-from {checkpoint num} \
  --generate \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command, a folder will be created next to the results folder with the suffix "-generated-{checkpoint num}".

You can also generate interpolations:

$ lightweight_gan --name {name of run} --generate-interpolation

Show progress

After several model checkpoints have been saved, you can render the training progression as a sequence of images with:

$ lightweight_gan \
  --name {name of run} \
  --show-progress \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command you will get a new folder in the results folder, with the suffix "-progress". You can convert the images to a video with ffmpeg using the command "ffmpeg -framerate 10 -pattern_type glob -i '*-ema.jpg' out.mp4".

Show progress gif demonstration

Show progress video demonstration

Discriminator output size

The author has kindly let me know that the discriminator output size (5x5 vs 1x1) leads to different results on different datasets (as an example, 5x5 works better for art than for faces). You can toggle this with a single flag:

# disc output size is by default 1x1
$ lightweight_gan --data ./path/to/art --image-size 512 --disc-output-size 5
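
For intuition, the flag just changes the spatial size of the discriminator's final logit map: a 1x1 output scores the image as a whole, while a 5x5 output scores a grid of overlapping patches. An illustrative sketch (the shapes and kernel sizes here are made up for the example, not taken from the repository):

import torch
from torch import nn

feat = torch.randn(1, 256, 8, 8)     # hypothetical final discriminator feature map

to_logit_1x1 = nn.Conv2d(256, 1, 8)  # 8x8 kernel over an 8x8 map -> a single logit
to_logit_5x5 = nn.Conv2d(256, 1, 4)  # 4x4 kernel -> a 5x5 grid of patch logits

print(to_logit_1x1(feat).shape)      # torch.Size([1, 1, 1, 1])
print(to_logit_5x5(feat).shape)      # torch.Size([1, 1, 5, 5])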

Attention

You can add linear + axial attention to specific resolution layers with the following:

# make sure there are no spaces between the values within the brackets []
$ lightweight_gan --data ./path/to/images --image-size 512 --attn-res-layers [32,64] --aug-prob 0.25
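
Conceptually, attention is inserted only at generator stages whose feature-map resolution appears in the list. A toy sketch of that dispatch (the module choices are illustrative and are not the repository's linear/axial attention):

import torch
from torch import nn

attn_res_layers = [32, 64]   # mirrors --attn-res-layers [32,64]

class ToyStage(nn.Module):
    def __init__(self, dim, use_attn):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, 4, batch_first=True) if use_attn else None
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        if self.attn is not None:
            b, c, h, w = x.shape
            seq = x.flatten(2).transpose(1, 2)                     # (b, h*w, c)
            attended, _ = self.attn(seq, seq, seq)
            x = x + attended.transpose(1, 2).reshape(b, c, h, w)   # residual attention
        return self.conv(x)

# attention is attached only where the stage resolution matches the flag
stages = [ToyStage(32, res in attn_res_layers) for res in (16, 32, 64, 128)]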

Dual Contrastive Loss

A recent paper has proposed that a novel contrastive loss between the real and fake logits can improve quality slightly over the default hinge loss.

You can use this with one extra flag as follows:

$ lightweight_gan --data ./path/to/images --dual-contrast-loss
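
For reference, a sketch of the general form of this loss, following Yu et al. (2021); this is illustrative and may differ from the repository's exact implementation. Each real logit is contrasted against all fake logits, and the mirrored term does the same with the signs flipped:

import torch
import torch.nn.functional as F

def dual_contrastive_loss(real_logits, fake_logits):
    real_logits, fake_logits = real_logits.flatten(), fake_logits.flatten()

    def half(pos, neg):
        # each positive logit (class 0) is contrasted against every negative logit
        logits = torch.cat((pos.unsqueeze(1), neg.unsqueeze(0).expand(pos.shape[0], -1)), dim=1)
        labels = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
        return F.cross_entropy(logits, labels)

    return half(real_logits, fake_logits) + half(-fake_logits, -real_logits)

loss = dual_contrastive_loss(torch.randn(8), torch.randn(8))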

Bonus

You can also train with transparent images:

$ lightweight_gan --data ./path/to/images --transparent

Or greyscale:

$ lightweight_gan --data ./path/to/images --greyscale

Alternatives

If you want the current state-of-the-art GAN, you can find it at https://github.com/lucidrains/stylegan2-pytorch

Citations

@inproceedings{
    anonymous2021towards,
    title   = {Towards Faster and Stabilized {GAN} Training for High-fidelity Few-shot Image Synthesis},
    author  = {Anonymous},
    booktitle = {Submitted to International Conference on Learning Representations},
    year    = {2021},
    url     = {https://openreview.net/forum?id=1Fqg133qRaI},
    note    = {under review}
}
@misc{cao2020global,
    title   = {Global Context Networks},
    author  = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year    = {2020},
    eprint  = {2012.13375},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{qin2020fcanet,
    title   = {FcaNet: Frequency Channel Attention Networks},
    author  = {Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
    year    = {2020},
    eprint  = {2012.11879},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yu2021dual,
    title   = {Dual Contrastive Loss and Attention for GANs}, 
    author  = {Ning Yu and Guilin Liu and Aysegul Dundar and Andrew Tao and Bryan Catanzaro and Larry Davis and Mario Fritz},
    year    = {2021},
    eprint  = {2103.16748},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}

What I cannot create, I do not understand - Richard Feynman


lightweight-gan's Issues

`pip install lightweight-gan` not working. ERROR: Command errored out with exit status 1: python setup.py egg_info

I tried it several times on my Windows 10 PC, and every time I got this error:

(lightweight-gan) E:\Projects\lightweight-gan>pip install lightweight-gan
Collecting lightweight-gan
  Using cached lightweight_gan-0.12.3-py3-none-any.whl (27 kB)
  Using cached lightweight_gan-0.12.2-py3-none-any.whl (27 kB)
  Using cached lightweight_gan-0.12.1-py3-none-any.whl (27 kB)
  Using cached lightweight_gan-0.12.0-py3-none-any.whl (27 kB)
  Using cached lightweight_gan-0.11.3-py3-none-any.whl (27 kB)
  Using cached lightweight_gan-0.11.2-py3-none-any.whl (27 kB)
  Using cached lightweight_gan-0.11.1-py3-none-any.whl (27 kB)
  Using cached lightweight_gan-0.11.0-py3-none-any.whl (27 kB)
Collecting adabelief-pytorch
  Using cached adabelief_pytorch-0.1.0-py3-none-any.whl (5.5 kB)
  Using cached adabelief_pytorch-0.0.5-py3-none-any.whl (4.8 kB)
INFO: pip is looking at multiple versions of lightweight-gan to determine which version is compatible with other requirements. This could take a while.
Collecting lightweight-gan
  Using cached lightweight_gan-0.10.0-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.9.2-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.9.1-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.9.0-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.8.5-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.8.4-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.8.3-py3-none-any.whl (26 kB)
INFO: pip is looking at multiple versions of lightweight-gan to determine which version is compatible with other requirements. This could take a while.
  Using cached lightweight_gan-0.8.2-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.8.0-py3-none-any.whl (26 kB)
  Using cached lightweight_gan-0.7.9-py3-none-any.whl (25 kB)
  Using cached lightweight_gan-0.7.8-py3-none-any.whl (25 kB)
  Using cached lightweight_gan-0.7.7-py3-none-any.whl (25 kB)
INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. If you want to abort this run, you can press Ctrl + C to do so. To improve how pip performs, tell us what happened here: https://pip.pypa.io/surveys/backtracking
  Using cached lightweight_gan-0.7.6-py3-none-any.whl (25 kB)
Collecting einops>=0.3
  Using cached einops-0.3.0-py2.py3-none-any.whl (25 kB)
Collecting fire
  Using cached fire-0.3.1.tar.gz (81 kB)
Collecting hamburger-pytorch
  Using cached hamburger_pytorch-0.0.3-py3-none-any.whl (3.4 kB)
Collecting numpy
  Downloading numpy-1.19.4-cp36-cp36m-win_amd64.whl (12.9 MB)
     |████████████████████████████████| 12.9 MB 6.8 MB/s
Collecting pillow
  Downloading Pillow-8.0.1-cp36-cp36m-win_amd64.whl (2.1 MB)
     |████████████████████████████████| 2.1 MB 6.4 MB/s
Collecting pytorch-fid
  Using cached pytorch-fid-0.2.0.tar.gz (11 kB)
    ERROR: Command errored out with exit status 1:
     command: 'C:\Users\user\AppData\Local\Continuum\anaconda3\envs\lightweight-gan\python.exe' -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'C:\\Users\\user\\AppData\\Local\\Temp\\pip-install-maq6gjoa\\pytorch-fid_ef9965f93e55457790c769e9b5008169\\setup.py'"'"'; __file__='"'"'C:\\Users\\user\\AppData\\Local\\Temp\\pip-install-maq6gjoa\\pytorch-fid_ef9965f93e55457790c769e9b5008169\\setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'
\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' egg_info --egg-base 'C:\Users\user\AppData\Local\Temp\pip-pip-egg-info-weq0pjeo'
         cwd: C:\Users\user\AppData\Local\Temp\pip-install-maq6gjoa\pytorch-fid_ef9965f93e55457790c769e9b5008169\
    Complete output (9 lines):
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "C:\Users\user\AppData\Local\Temp\pip-install-maq6gjoa\pytorch-fid_ef9965f93e55457790c769e9b5008169\setup.py", line 34, in <module>
        packages=setuptools.find_packages(where='src/'),
      File "C:\Users\user\AppData\Local\Continuum\anaconda3\envs\lightweight-gan\lib\site-packages\setuptools\__init__.py", line 64, in find
        convert_path(where),
      File "C:\Users\user\AppData\Local\Continuum\anaconda3\envs\lightweight-gan\lib\distutils\util.py", line 127, in convert_path
        raise ValueError("path '%s' cannot end with '/'" % pathname)
    ValueError: path 'src/' cannot end with '/'
    ----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

It seems pip tried to use the environment in Anaconda's system directory, but I have a virtualenv in the project folder. I successfully executed these commands:

conda create --prefix E:\Projects\lightweight-gan\env python=3.9.0
conda activate E:\Projects\lightweight-gan\env
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

Do you have any ideas about what I can do about this error?

[Question] Change in performance

Updating from version 0.5.6 to the latest release, 0.6.1 (at the time of this post), I've noticed a 2x increase in seconds/it using the same settings on the same dataset.
Is this expected behavior, considering the changes made?

Optimal parameters for Google Colab

Hello,

First of all, thank you for sharing your code and insights with the rest of us!

As for your code, I plan to run it for 12 hours on Google Colab, similar to the set-up shown in the README.

My dataset consists of images at 256x256 resolution, and I have started training with the following command line:

!lightweight_gan \
 --data {image_dir} \
 --disc-output-size 5 \
 --aug-prob 0.25 \
 --aug-types [translation,cutout,color] \
 --amp \

I have noticed that the expected training time is 112.5 hours for 150k iterations (the default setting), which is consistent with the average time of 2.7 seconds per iteration shown in the log. However, this is ~9 times more than what is shown in the README. So I wonder if I am doing something wrong, and I see two solutions.

First, I could decrease the number of iterations so that it takes 12 hours, by choosing 16k iterations instead of 150k with:

 --num-train-steps 16000 \

Is that what you did for the results shown in the README?

Second, I have noticed that I am only using 3.8 GB of GPU memory, so I could increase the batch size, as you mentioned in #13 (comment).
Edit: However, the training time increases with a larger batch size.
For instance, I am using 7.2 GB of GPU memory, and it takes 8.2 seconds per iteration, with the following:

 --batch-size 32 \
 --gradient-accumulate-every 4 \

Loss functions of every step

I have noticed that the loss functions and evaluation values fluctuate a lot, and I wonder whether the optimal number of steps isn't smaller than the default of 150,000. How can I get the information that is printed at each step into a CSV file? Has anyone done this already?
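
One pragmatic approach, assuming you redirect the console output to a file: parse the printed "G: … | D: …" lines (see the Losses issue below for what they mean) into a CSV. A sketch, with train.log as a hypothetical capture of the output:

import csv
import re

pattern = re.compile(r'(\w+):\s*(-?\d+(?:\.\d+)?)')

# 'train.log' is a hypothetical file containing the redirected console output
with open('train.log') as src, open('losses.csv', 'w', newline='') as dst:
    writer = csv.writer(dst)
    writer.writerow(['line', 'G', 'D', 'GP', 'SS'])
    for i, line in enumerate(src):
        pairs = dict(pattern.findall(line))
        if 'G' in pairs and 'D' in pairs:
            writer.writerow([i, pairs['G'], pairs['D'], pairs.get('GP'), pairs.get('SS')])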

[Errno 32] Broken pipe

Thanks for the great GAN!
When I run lightweight_gan --data ./path/to/images --image-size 512, an error occurs:
BrokenPipeError: [Errno 32] Broken pipe

This is probably due to the behavior of num_workers in PyTorch on Windows 10.
Is there any way to solve this?
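
For reference, the usual Windows workaround for this is to load data with num_workers=0, since worker subprocesses are what trigger the broken pipe. A generic sketch with a stand-in dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3, 256, 256))  # stand-in for the image folder dataset
# num_workers=0 keeps data loading in the main process, avoiding the Windows pipe issue
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)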

Amount of training steps

If I bring the number of training steps down from 150,000 to 30,000, will the trained model be bad overall? Does it really need the 100,000 or 150,000 training steps?

Error when specifying training settings

When I try to modify the training settings or just run the recommended command for the training settings, this error shows up:
num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)
TypeError: '>' not supported between instances of 'str' and 'int'

I can't figure out what the problem is.

How to improve quality

Thanks for putting this together! I'm having some success using this to create stylized artwork. I'm wondering what the avenues to improve quality are. It sounds like training for longer helps, along with adding attention. Is there a --network-capacity flag similar to your stylegan2 project? Should increasing the number of feature maps (fmap_max) help? What about increasing the size of latent_dim?

If we scale up to multi-GPU should we scale the learning rate a corresponding amount?

Greyscale image generation

Hi,

thank you for this repo, I've been playing with it a bit and it seems very good!
I am trying to generate greyscale images, so I modified the channels accordingly:

init_channel = 4 if transparent else 1

Unfortunately, this seemed to have no effect, as the generated images are still RGB (even though they converge towards greyscale with time). Even weirder, in my opinion, is that I can modify the number of channels for the generator and keep the original 3 for the discriminator without any issue.

I have also changed this part to no effect

convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'L')
num_channels = 1 if not transparent else 4

Am I missing something here?

Why have the hamburger layers been removed?

More a curious question:

I remember that an earlier version of this repository used hamburger attention layers for the attention layers, but these were later replaced with global self-attention. I am just curious why.

Were the hamburger attention layers not good, or did they not lead to any improvements in training or visual quality?

Truncation psi not doing anything?

This may be me missing something screamingly obvious, but I can't see anywhere in the code where the trunc_psi parameter is actually used. In particular, generate_truncated looks like this:

    def generate_truncated(self, G, style, trunc_psi = 0.75, num_image_tiles = 8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

That doesn't seem to do the truncation at all, and experimentation backs this up: varying psi has no discernible effect on the generated images.
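
For reference, the standard truncation trick interpolates each latent toward a mean latent, so psi would enter roughly as below. This is a sketch of the usual technique, not code from this repository:

import torch

@torch.no_grad()
def truncate(latents, mean_latent, trunc_psi=0.75):
    # psi = 1 leaves latents unchanged; psi = 0 collapses them all to the mean
    return mean_latent + trunc_psi * (latents - mean_latent)

mean_latent = torch.randn(1, 256)   # in practice: the average of many sampled latents
latents = torch.randn(8, 256)
truncated = truncate(latents, mean_latent)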

Using the generator outside of the repo

Hi, I'm trying to apply closed-form factorization of the weights to explore the latent space of a model I have trained.
Unfortunately, using anything outside of the module itself doesn't work as expected.

I've installed the repo, so I am able to run every command using lightweight_gan with no problems; I can generate interpolations, grid images, etc.
Whenever I try to do the same outside of the repo (importing the Generator or Trainer class, or some functions), I can't reproduce the results.
Two things to notice:

  1. The generated image is not completely random; I can see that the badly-generated image (right) shares some resemblance with the well-generated images (left).
  2. My generator has missing layers (maybe?)
    train = Trainer(name=args.modeltype, models_dir=args.weights, attn_res_layers=[32,64], image_size=256)
    train.load(325)
    print(train.GAN.G)
    gives me this:
Generator(
  (initial_conv): Sequential(
    (0): ConvTranspose2d(256, 512, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GLU(dim=1)
  )
  (layers): ModuleList(
    (0): ModuleList(
      (0): Sequential(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Identity()
        (2): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): GLU(dim=1)
      )
      (1): None
      (2): None
      (3): None
    )
    (1): ModuleList(
      (0): Sequential(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Identity()
        (2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): GLU(dim=1)
      )
      (1): SLE(
        (avg_pool): AdaptiveAvgPool2d(output_size=(4, 4))
        (max_pool): AdaptiveMaxPool2d(output_size=(4, 4))
        (net): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(4, 4), stride=(1, 1))
          (1): LeakyReLU(negative_slope=0.1)
          (2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (3): Sigmoid()
        )
      )
      (2): None
      (3): None
    )
    (2): ModuleList(
      (0): Sequential(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Identity()
        (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): GLU(dim=1)
      )
      (1): SLE(
        (avg_pool): AdaptiveAvgPool2d(output_size=(4, 4))
        (max_pool): AdaptiveMaxPool2d(output_size=(4, 4))
        (net): Sequential(
          (0): Conv2d(512, 128, kernel_size=(4, 4), stride=(1, 1))
          (1): LeakyReLU(negative_slope=0.1)
          (2): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
          (3): Sigmoid()
        )
      )
      (2): None
      (3): None
    )
    (3): ModuleList(
      (0): Sequential(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Identity()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): GLU(dim=1)
      )
      (1): None
      (2): None
      (3): Rezero(
        (fn): GSA(
          (to_qkv): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (to_out): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (4): ModuleList(
      (0): Sequential(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Identity()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): GLU(dim=1)
      )
      (1): None
      (2): None
      (3): None
    )
    (5): ModuleList(
      (0): Sequential(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Identity()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): GLU(dim=1)
      )
      (1): None
      (2): None
      (3): None
    )
  )
  (out_conv): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

Any idea what is happening? It feels like a simple problem, but I can't make it work...

How can I save and load latent codes?

How can I save latent codes and load them later, to interpolate them, edit them, etc.?
Is there a function for that? If not, please add one; it would be very useful for artists.

Losses

I think it is not clear what the losses displayed in the log are.

def print_log(self):
    data = [
        ('G', self.g_loss),
        ('D', self.d_loss),
        ('GP', self.last_gp_loss),
        ('SS', self.last_recon_loss),
        ('FID', self.last_fid)
    ]

My understanding is that, in G: 1.37 | D: 1.80 | GP: 0.97 | SS: 0.02, we have:

  • the loss $L_G$ for the Generator (G),
  • the (total) loss $L_D$ for the Discriminator (D),
  • a loss from Gradient Penalty (GP),
  • a reconstruction loss from the Self-Supervision (SS) of the discriminator.

It would be nice to mention it somewhere so that we can try to understand what happens during training. :)
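
(For reference, assuming the default hinge loss, the two adversarial terms would take the standard form $L_D = \mathbb{E}_x[\max(0,\, 1 - D(x))] + \mathbb{E}_z[\max(0,\, 1 + D(G(z)))]$ and $L_G = -\mathbb{E}_z[D(G(z))]$, with the gradient penalty and the reconstruction loss added on the discriminator side.)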

RuntimeError: leaf variable has been moved into the graph interior

Did someone meet this error?
I have been running training for several days, resuming from a previous stage, and today it threw an error without any changes on my side (including NVIDIA drivers).
Error trace:

(lgan) E:\Projects\lightweight-gan>lightweight_gan --data ./pano/1024 --image-size 1024 --aug-types [offset_h] --aug-prob 1 --name pano-1024 --amp --sle-spatial --batch-size 4 --num-train-steps 1000000
continuing from previous epoch - 71
loading from version 0.14.1
pano-1024<./pano/1024>:   7%|██████████████████████████████▎                                                                                                                                                                                                                                                                                                                                                                                                            | 71000/1000000 [00:15<?, ?it/s]
Traceback (most recent call last):
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\Scripts\lightweight_gan-script.py", line 33, in <module>
    sys.exit(load_entry_point('lightweight-gan', 'console_scripts', 'lightweight_gan')())
  File "e:\projects\lightweight-gan\lightweight-gan\lightweight_gan\cli.py", line 164, in main
    fire.Fire(train_from_folder)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\lib\site-packages\fire\core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\lib\site-packages\fire\core.py", line 468, in _Fire
    target=component.__name__)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\lib\site-packages\fire\core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "e:\projects\lightweight-gan\lightweight-gan\lightweight_gan\cli.py", line 155, in train_from_folder
    run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
  File "e:\projects\lightweight-gan\lightweight-gan\lightweight_gan\cli.py", line 60, in run_training
    retry_call(model.train, tries=3, exceptions=NanException)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\lib\site-packages\retry\api.py", line 101, in retry_call
    return __retry_internal(partial(f, *args, **kwargs), exceptions, tries, delay, max_delay, backoff, jitter, logger)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\lib\site-packages\retry\api.py", line 33, in __retry_internal
    return f()
  File "e:\projects\lightweight-gan\lightweight-gan\lightweight_gan\lightweight_gan.py", line 1039, in train
    self.D_scaler.scale(disc_loss).backward()
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: leaf variable has been moved into the graph interior

lightweight_gan.py", line 1039, in train self.D_scaler.scale(disc_loss).backward():
https://github.com/lucidrains/lightweight-gan/blob/main/lightweight_gan/lightweight_gan.py#L1039

I also tried resuming from previous steps with --load-from, to no effect.
And of course I tried restarting my PC =)

.tif support

Any chance we could get support for .tif files as well?

Unable to use "Show Progress"

Summary: Running lightweight_gan --models_dir "path" --show-progress results in an AttributeError: 'NoneType' object has no attribute 'split'

Full error:
Generating progress images: 0% 0/30 [00:04<?, ?it/s]
Traceback (most recent call last):
  File "/usr/local/bin/lightweight_gan", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/cli.py", line 185, in main
    fire.Fire(train_from_folder)
  File "/usr/local/lib/python3.6/dist-packages/fire/core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.6/dist-packages/fire/core.py", line 468, in _Fire
    target=component.__name__)
  File "/usr/local/lib/python3.6/dist-packages/fire/core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/cli.py", line 166, in train_from_folder
    model.show_progress(num_images=num_image_tiles, types=generate_types)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1217, in show_progress
    generated_image = self.generate_truncated(self.GAN.G, latents)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1272, in generate_truncated
    generated_images = evaluate_in_chunks(self.batch_size, G, style)
  File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 99, in evaluate_in_chunks
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
  File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 99, in <lambda>
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
AttributeError: 'NoneType' object has no attribute 'split'

I've tried populating all the fields listed in the readme just to check, with no luck.

For additional context here's the command used to train the model and generate the checkpoints:
lightweight_gan --num-train-steps 150000 --data /content/images --optimizer "adabelief" --attn-res-layers [32,64,128] --image-size 512 --disc-output-size 5 --models_dir "/LightGAN/" --results_dir "/LightGAN/results" --calculate_fid_every 25000

running in jupyter

lightweight_gan --data /lightweight-gan/images --image-size 512
While running this in a Jupyter notebook, it gives a syntax error. Please help.

ZeroDivisionError: float division by zero

Caught this error at step 472,676 (batch size 3):

Traceback (most recent call last):
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan_p3.8\Scripts\lightweight_gan-script.py", line 33, in <module>
    sys.exit(load_entry_point('lightweight-gan', 'console_scripts', 'lightweight_gan')())
  File "\lightweight_gan\cli.py", line 185, in main
    fire.Fire(train_from_folder)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan_p3.8\lib\site-packages\fire\core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan_p3.8\lib\site-packages\fire\core.py", line 463, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan_p3.8\lib\site-packages\fire\core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "\lightweight_gan\cli.py", line 176, in train_from_folder
    run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
  File "\lightweight_gan\cli.py", line 66, in run_training
    retry_call(model.train, tries=3, exceptions=NanException)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan_p3.8\lib\site-packages\retry\api.py", line 101, in retry_call
    return __retry_internal(partial(f, *args, **kwargs), exceptions, tries, delay, max_delay, backoff, jitter, logger)
  File "C:\Users\oleg\AppData\Local\Continuum\anaconda3\envs\lgan_p3.8\lib\site-packages\retry\api.py", line 33, in __retry_internal
    return f()
  File "\lightweight_gan\lightweight_gan.py", line 1055, in train
    inv_scale = (1. / self.D_scaler.get_scale()) if self.amp else 1.
ZeroDivisionError: float division by zero

Spectral Normalization

I noticed that the paper mentions spectral normalization (Miyato et al., 2018) but I don't see it used in your implementation. Have you tried it?

Troubles with global context module in 0.15.0

@lucidrains

After updating to version 0.15.0 (https://github.com/lucidrains/lightweight-gan/releases/tag/0.15.0), I can't continue training my network and had to start from zero.
The previous version had reached 117k batches of 4 (468k images, around 66 hours of training) and looked pretty good.
With the new version, 0.15.0, on the same dataset with the same parameters (--image-size 1024 --aug-types [color,offset_h] --aug-prob 1 --amp --batch-size 7), after 77k batches of 7 (539k images, around 49 hours of training) I see artifacts that look like oil puddles. Have you encountered this, or do you know how to avoid it?


In the previous version, with sle-spatial, I didn't see anything like this.

Resume training

Is there a way to resume training from a saved checkpoint?
I don't see any commands for that in the README.

training does not start


When I run
! python /content/lightweight-gan/lightweight_gan/lightweight_gan.py --data /content/faces/ --image-size 256
training does not start; the cell finishes executing immediately and nothing happens.

I am using colab.

AdaBelief and AMP seem incompatible

I tried the AdaBelief optimizer in combination with mixed-precision training (AMP), but after a few hundred iterations it goes out of control for me: the losses end up as NaNs, and eventually it crashes with a division by zero.

Has anyone found stable training parameters with AdaBelief and AMP? I haven't seen this issue with normal FP32 training.

Do I need Nvidia GPU with CUDA to generate images from a model?

When trying to generate images from a trained model, I get the error:

  return torch._C._cuda_getDeviceCount() > 0
Traceback (most recent call last):
  File "/home/magic/.local/bin/lightweight_gan", line 5, in <module>
    from lightweight_gan.cli import main
  File "/home/magic/.local/lib/python3.8/site-packages/lightweight_gan/__init__.py", line 1, in <module>
    from lightweight_gan.lightweight_gan import LightweightGAN, Generator, Discriminator, Trainer, NanException
  File "/home/magic/.local/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 41, in <module>
    assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
AssertionError: You need to have an Nvidia GPU with CUDA installed.

Do I really need an Nvidia card to generate images from a pre-trained model? I wanted to generate images on my VPS as a web service...

Is it possible to add the experimental options from your StyleGAN2 to this GAN model?

Hello, this GAN you provided is very meaningful for my research, as I have a limited set of images (only close to 9,000).
I would like to ask whether you could add the experimental features from your StyleGAN2 (https://github.com/lucidrains/stylegan2-pytorch#experimental) to the training of this GAN model. This could enrich my research.
Your experimental options in StyleGAN2 are these:
1. Top-k Training for Generator
2. Feature Quantization
3. Contrastive Loss Regularization
4. Relativistic Discriminator Loss

FID calculation arguments missing device?

The args to calculate_fid_given_paths are given here:

return fid_score.calculate_fid_given_paths([real_path, fake_path], 256, True, 2048)

However, looking at the implementation in pytorch-fid, it seems to expect a device rather than the bool passed above:
https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py#L240
def calculate_fid_given_paths(paths, batch_size, device, dims):
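
A possible fix matching that signature, with hypothetical folder paths:

import torch
from pytorch_fid import fid_score

real_path, fake_path = './fid_real', './fid_fake'  # hypothetical image folders
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# pass the device object where the bool was being passed before
fid = fid_score.calculate_fid_given_paths([real_path, fake_path], 256, device, 2048)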

FID calculation resaves real images everytime

What is the reasoning behind deleting the "fid_real" image folder every time, and then recreating and re-saving all the real images every time the FID calculation is run?

torchvision.utils.save_image(real_batch[k, :, :, :], real_path + '{}.png'.format(k + batch_num * self.batch_size))

This can take a surprising amount of time for large datasets, and I think it's unnecessary: for the real images, this only needs to happen once, unless I'm mistaken?

Error when installing

C:\Users\Lenovo>pip install lightweight-gan
WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.
Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
Looking in indexes: http://pypi.douban.com/simple
Collecting lightweight-gan
  Downloading http://pypi.doubanio.com/packages/3b/be/3cc7cf30cb6c121d323d2effc5758a24f464774a5919adb4308315860516/lightweight_gan-0.14.1-py3-none-any.whl (28 kB)
Requirement already satisfied: numpy in c:\users\lenovo\appdata\roaming\python\python36\site-packages (from lightweight-gan) (1.18.3)
Requirement already satisfied: torchvision in c:\programdata\anaconda3\lib\site-packages (from lightweight-gan) (0.6.0+cpu)
Requirement already satisfied: tqdm in c:\users\lenovo\appdata\roaming\python\python36\site-packages (from lightweight-gan) (4.45.0)
Requirement already satisfied: scipy in c:\programdata\anaconda3\lib\site-packages (from lightweight-gan) (1.4.1)
Collecting einops>=0.3
  Downloading http://pypi.doubanio.com/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl (25 kB)
Requirement already satisfied: pillow in c:\users\lenovo\appdata\roaming\python\python36\site-packages (from lightweight-gan) (6.0.0)
Collecting gsa-pytorch
  Downloading http://pypi.doubanio.com/packages/f5/d9/edcfbc07155cf8c9757728e7a46077f0387786bdde07cffc14ad06ea790b/gsa_pytorch-0.2.2-py3-none-any.whl (3.6 kB)
Collecting retry
  Downloading http://pypi.doubanio.com/packages/4b/0d/53aea75710af4528a25ed6837d71d117602b01946b307a3912cb3cfcbcba/retry-0.9.2-py2.py3-none-any.whl (8.0 kB)
Collecting fire
  Downloading http://pypi.doubanio.com/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81 kB)
     |████████████████████████████████| 81 kB 1.3 MB/s
Collecting kornia
  Downloading http://pypi.doubanio.com/packages/25/f6/9fb4cc2c67796680c8041fa6ffdee5f280e4cf65c86835768a700a324d59/kornia-0.4.1-py2.py3-none-any.whl (225 kB)
     |████████████████████████████████| 225 kB 226 kB/s
Collecting pytorch-fid
  Downloading http://pypi.doubanio.com/packages/93/54/49dc21a5ee774af0390813c3cf66af57af0a31ab22ba0c2ac02cdddeb755/pytorch-fid-0.2.0.tar.gz (11 kB)
    ERROR: Command errored out with exit status 1:
     command: 'C:\ProgramData\Anaconda3\python.exe' -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'C:\\Users\\Lenovo\\AppData\\Local\\Temp\\pip-install-d1z9rk7k\\pytorch-fid\\setup.py'"'"'; __file__='"'"'C:\\Users\\Lenovo\\AppData\\Local\\Temp\\pip-install-d1z9rk7k\\pytorch-fid\\setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' egg_info --egg-base 'C:\Users\Lenovo\AppData\Local\Temp\pip-pip-egg-info-7fajsvz9'
         cwd: C:\Users\Lenovo\AppData\Local\Temp\pip-install-d1z9rk7k\pytorch-fid\
    Complete output (9 lines):
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "C:\Users\Lenovo\AppData\Local\Temp\pip-install-d1z9rk7k\pytorch-fid\setup.py", line 34, in <module>
        packages=setuptools.find_packages(where='src/'),
      File "C:\Users\Lenovo\AppData\Roaming\Python\Python36\site-packages\setuptools\__init__.py", line 71, in find
        convert_path(where),
      File "C:\ProgramData\Anaconda3\lib\distutils\util.py", line 127, in convert_path
        raise ValueError("path '%s' cannot end with '/'" % pathname)
    ValueError: path 'src/' cannot end with '/'
    ----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

Shape mismatch, can't divide axis of length 3 in chunks of 2

I'm running with the following arguments:

lightweight_gan --data images --image-size 1024 --aug-prob 0.2 --batch-size 16 --gradient-accumulate-every 4 --attn-res-layers [32,64] --amp

But getting the following error:

einops.EinopsError: Error while processing rearrange-reduction pattern "b (g c) h w -> b g c h w". Input tensor shape: torch.Size([16, 3, 512, 512]). Additional info: {'g': 2}. Shape mismatch, can't divide axis of length 3 in chunks of 2

Getting NoneType is not subscriptable when trying to start training.

I've been able to train models before, but after changing my dataset I'm getting this error.

My trace:
File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1356, in load
name = checkpoints[-1]
TypeError: 'NoneType' object is not subscriptable

AssertionError: image size must be a power of 2

Thanks for this implementation; it looks very promising. It's training much faster than stylegan2-pytorch.

Release 0.9 was working well, but now with 0.9.1 I get the following error no matter what image size I use:

File "c:\users\timothy\miniconda3\envs\lwgan\lib\site-packages\lightweight_gan\lightweight_gan.py", line 308, in __init__ assert is_power_of_two(resolution), 'image size must be a power of 2' AssertionError: image size must be a power of 2

Multiclass training and inference

I'd like to use lightweight-gan for a multiclass dataset.

The idea is to train the GAN with multiple classes.

Then, during inference, ask the GAN for an image with multiple tags, i.e. generate an image tagged as 'boat', 'sunset', and 'people'.

Is it possible with lightweight-gan?
