
crcrpar / pytorch.sngan_projection


An unofficial PyTorch implementation of SNGAN (ICLR 2018) and cGANs with projection discriminator (ICLR 2018).

License: MIT License

Python 97.51% Dockerfile 2.49%

pytorch.sngan_projection's Issues

implementation wrong

utils.py line 68

Hello,

May I ask a question?

In utils.py, lines 68-69:
if distribution is None:
    distribution == 'normal'
I think this should be distribution = 'normal'.

If I'm wrong, please tell me what the == means here.

Thanks for your nice code :)
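A minimal sketch of why the question matters: in Python, `==` is a comparison whose result is discarded when written as a bare statement, so the default is never assigned. The function names below are hypothetical, used only for illustration; they are not functions from utils.py.

```python
# Hypothetical helpers illustrating the default-argument pattern from utils.py.


def resolve_distribution_buggy(distribution=None):
    if distribution is None:
        # `==` only compares; the boolean result is thrown away.
        distribution == 'normal'
    return distribution


def resolve_distribution_fixed(distribution=None):
    if distribution is None:
        # `=` actually assigns the default value.
        distribution = 'normal'
    return distribution


print(resolve_distribution_buggy())           # None
print(resolve_distribution_fixed())           # normal
print(resolve_distribution_fixed('uniform'))  # uniform
```

With the comparison, a caller that omits `distribution` silently receives `None`, which then has to be handled (or fails) further down.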

exp-1: cGAN with projection discriminator

Due to limited GPU resources, I ran this on a 1080 Ti for 12 hours (around 20K iterations).
The dataset is Tiny ImageNet. Even in the original images, it is somewhat difficult to tell what each image shows.

Config

Same as the paper.

Results

After 20K iterations:
fake_020_iter_0020000

Randomly sampled class outputs

ID 20

image_iter_0020000_batch_0020

ID 181

image_iter_0020000_batch_0181

ID 195

image_iter_0020000_batch_0195

concat discriminator forward error

Traceback (most recent call last):
  File "train_64.py", line 402, in <module>
    main()
  File "train_64.py", line 325, in main
    dis_fake = dis(fake, pseudo_y)
  File "/home/crcrpar/.pyenv/versions/env1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/crcrpar/.ghq/github.com/crcrpar/pytorch.sngan_projection/models/discriminators/snresnet64.py", line 91, in forward
    h = self.block1(h)
  File "/home/crcrpar/.pyenv/versions/env1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/crcrpar/.ghq/github.com/crcrpar/pytorch.sngan_projection/models/discriminators/resblocks.py", line 72, in forward
    return self.shortcut(x) + self.residual(x)
  File "/home/crcrpar/.ghq/github.com/crcrpar/pytorch.sngan_projection/models/discriminators/resblocks.py", line 75, in shortcut
    return self.c_sc(F.avg_pool2d(x, 2))
  File "/home/crcrpar/.pyenv/versions/env1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/crcrpar/.pyenv/versions/env1/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [64, 3, 1, 1], expected input[64, 256, 4, 4] to have 3 channels, but got 256 channels instead
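The error message can be read directly off the tensor shapes: the shortcut conv's weight `[64, 3, 1, 1]` means 64 output channels, 3 input channels and a 1x1 kernel, while the pooled tensor reaching it, `[64, 256, 4, 4]`, has 256 channels. This suggests the shortcut in `block1` was constructed for 3-channel (RGB) input but is being fed a 256-channel feature map. The pure-Python sketch below mimics the channel check (groups=1 or more) only to explain the message; `check_conv_input` is a hypothetical helper, not PyTorch's implementation.

```python
# Hypothetical re-creation of the channel check that a 2D convolution performs.


def check_conv_input(weight_shape, input_shape, groups=1):
    out_ch, in_ch_per_group, kh, kw = weight_shape  # Conv2d weight layout
    n, c, h, w = input_shape                          # NCHW input layout
    expected = in_ch_per_group * groups
    if c != expected:
        raise RuntimeError(
            f"weight of size {list(weight_shape)} expects input with "
            f"{expected} channels, but got {c} channels instead"
        )
    return True


# Shortcut conv built for RGB input (3 channels) ...
shortcut_weight = (64, 3, 1, 1)
# ... but the feature map it receives has 256 channels.
pooled_features = (64, 256, 4, 4)

try:
    check_conv_input(shortcut_weight, pooled_features)
except RuntimeError as err:
    print(err)
```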

about why train failed

Hi, I trained on a 64x64 dataset with your code, but training failed, so I tested a few components, such as the D network and the loss.
Here are my results:

  1. The SN projection discriminator itself is fine and works well, but the code uses it incorrectly when computing dis_fake: I think the label passed to D should be the true label, not pseudo_y. You can try it; pseudo_y is not useful at all.

  2. The hinge loss causes some trouble; with the LSGAN loss, training works fine.
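For reference, a minimal sketch of least-squares (LSGAN) losses computed on raw discriminator outputs, written with plain Python lists. The names `dis_lsgan` and `gen_lsgan` are hypothetical, and the 1 (real) / 0 (fake) targets are one common LSGAN convention; the comment above does not say which variant was used.

```python
# LSGAN losses on lists of raw discriminator outputs (assumed targets: real=1, fake=0).


def _mean(xs):
    return sum(xs) / len(xs)


def dis_lsgan(dis_fake, dis_real):
    # Push D(real) toward 1 and D(fake) toward 0.
    real_term = _mean([(d - 1.0) ** 2 for d in dis_real])
    fake_term = _mean([d ** 2 for d in dis_fake])
    return 0.5 * (real_term + fake_term)


def gen_lsgan(dis_fake):
    # The generator pushes D(fake) toward the "real" target 1.
    return 0.5 * _mean([(d - 1.0) ** 2 for d in dis_fake])


print(dis_lsgan([0.0, 0.0], [1.0, 1.0]))  # 0.0 (perfect discriminator)
print(gen_lsgan([0.0, 0.0]))              # 0.5
```

Unlike the hinge loss, these squared terms keep producing gradients for samples that are already on the correct side of the target, which is one plausible reason the behaviour differs in practice.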

network structure bugs

I found some differences between these networks and those in the reference implementation (https://github.com/pfnet-research/sngan_projection). Please check them below.

pytorch.sngan_projection/models/discriminators/snresnet.py

activation=activation, downsample=True)

->

activation=activation, downsample=False)

pytorch.sngan_projection/models/discriminators/resblocks.py


->

h_ch = in_ch

self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, ksize, 1, 0))

->

self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0))

return F.avg_pool2d(x, 2, padding=1)

->

return F.avg_pool2d(x, 2)

pytorch.sngan_projection/models/generators/resnet.py

self.conv7 = nn.Conv2d(num_features, 3, 1, 1)

->

self.conv7 = nn.Conv2d(num_features, 3, 3, 1, 1)
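Some of the suggested changes can be checked with output-size arithmetic alone. The sketch below uses the spatial-size formulas from the PyTorch documentation for `nn.Conv2d` and `nn.AvgPool2d` (with `ceil_mode=False`); the helper names, the 64x64 input size and `ksize=3` for the shortcut conv are assumptions for illustration.

```python
# Output-size arithmetic for the layers listed above (assumed 64x64 input, ksize=3).


def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def avg_pool2d_out(size, kernel, stride=None, padding=0):
    stride = kernel if stride is None else stride  # AvgPool2d default stride
    return (size + 2 * padding - kernel) // stride + 1


H = 64
# Shortcut pooling: padding=1 yields one extra row/column.
print(avg_pool2d_out(H, 2, padding=1), avg_pool2d_out(H, 2))  # 33 32
# Shortcut conv: ksize=3 with padding 0 shrinks the map; a 1x1 conv does not.
print(conv2d_out(H, 3, 1, 0), conv2d_out(H, 1, 1, 0))  # 62 64
# Generator conv7: both keep 64x64; they differ only in receptive field.
print(conv2d_out(H, 1, 1, 0), conv2d_out(H, 3, 1, 1))  # 64 64
```

Assuming the residual branch halves the resolution exactly, a shortcut producing 33x33 (or 62x62 before pooling) could not be summed element-wise with it, whereas the 1x1 conv plus `F.avg_pool2d(x, 2)` gives matching shapes. The `conv7` change does not affect shapes; it replaces a 1x1 output layer with a 3x3 one.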

bug in fid: normalization

Hi, I found a bug in the FID score computation.

The normalization applied before the Inception model is written as:

if self.normalize_input:
    x = x.clone()
    x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
    x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
    x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5

but it should be:
https://github.com/mseitzer/pytorch-fid/blob/4e366b2fc9fb933bec9f6f24c5e87c3bd9452eda/inception.py#L130-L131

or it could be fixed like this:

if self.normalize_input:
    # Create the constants on x's device instead of hard-coding .cuda().
    mean = torch.tensor((0.485, 0.456, 0.406), device=x.device).view(1, 3, 1, 1)
    std = torch.tensor((0.229, 0.224, 0.225), device=x.device).view(1, 3, 1, 1)
    x = (x - mean) / std
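The bug can be seen with scalar arithmetic. The repository expression appears to match the conversion torchvision's Inception v3 applies with `transform_input=True`, which assumes the input is already ImageNet-normalized; applied to raw pixels it does not produce the [-1, 1] range. The sketch below assumes images reach the Inception wrapper scaled to [0, 1] (the convention in the linked pytorch-fid code); all function names are hypothetical.

```python
# Scalar, per-channel check of the normalization discussed above.
# ASSUMPTION: pixel values arrive in [0, 1].

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def repo_transform(v, c):
    # Expression as written in the repository, for channel c.
    return v * (STD[c] / 0.5) + (MEAN[c] - 0.5) / 0.5


def scale_to_pm1(v):
    # What the linked pytorch-fid code does: [0, 1] -> [-1, 1].
    return 2 * v - 1


def imagenet_normalize(v, c):
    # The (x - mean) / std fix proposed in the issue.
    return (v - MEAN[c]) / STD[c]


for c in range(3):
    lo, hi = repo_transform(0.0, c), repo_transform(1.0, c)
    print(f"channel {c}: [0, 1] -> [{lo:.3f}, {hi:.3f}]")

# The repository expression yields 2v - 1 only when its input was
# already ImageNet-normalized.
v = 0.7
print(repo_transform(imagenet_normalize(v, 0), 0), scale_to_pm1(v))
```

For channel 0, for example, the repository expression maps [0, 1] to roughly [-0.03, 0.43] instead of [-1, 1].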

exp-2 SNGAN

Not relativistic.

After 16K iterations:
fake_latest

After 20K iterations:

fake_20_iter_0020000
fake_20_iter_0020100
fake_20_iter_0020200

RuntimeError: CUDA error: initialization error

I ran python train_64.py --no_tensorboard.
After the 'Initialized models...' message is printed, I get RuntimeError: CUDA error: initialization error. I am fairly sure my PyTorch version is 1.1.0. Any suggestions are welcome. Thank you.

Here is the error output I got:

Traceback (most recent call last):
  File "train_64.py", line 402, in <module>
    main()
  File "train_64.py", line 341, in main
    real, y = sample_from_data(args, device, train_loader)
  File "train_64.py", line 199, in sample_from_data
    real, y = next(data_loader)
  File "/vm/data/rdata/shared/conda_envs/fastaiv1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
    return self._process_next_batch(batch)
  File "/vm/data/rdata/shared/conda_envs/fastaiv1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
  File "/vm/data/rdata/shared/conda_envs/fastaiv1/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/vm/data/rdata/shared/conda_envs/fastaiv1/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 68, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/vm/data/rdata/shared/conda_envs/fastaiv1/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 68, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/vm/data/rdata/shared/conda_envs/fastaiv1/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 59, in default_collate
    return torch.tensor(batch)

possibly typo

def __call__(self, dis_fake, dis_real=None, **kwargs):
    if not self.is_relativistic:
        if self.loss_type == "hinge":
            return gen_hinge(dis_fake, dis_real)
        elif self.loss_type == "dcgan":
            return gen_dcgan(dis_fake, dis_real)

s/gen/dis/g

The gen_* losses ignore the dis_real term, so this discriminator loss should call the dis_* functions instead.
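A minimal sketch of what the corrected dispatch could look like, on plain Python lists of discriminator logits. `DisLoss`, `dis_hinge` and `dis_dcgan` here are hypothetical stand-ins rather than the repository's code; the hinge form is the standard one, the "dcgan" form assumes the usual softplus (non-saturating) loss, and the relativistic branch is deliberately left out.

```python
import math


def _mean(xs):
    return sum(xs) / len(xs)


def _softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dis_hinge(dis_fake, dis_real):
    # Both terms are used: real logits should exceed +1, fake logits stay below -1.
    return (_mean([max(0.0, 1.0 - d) for d in dis_real])
            + _mean([max(0.0, 1.0 + d) for d in dis_fake]))


def dis_dcgan(dis_fake, dis_real):
    # -log sigmoid(real) - log(1 - sigmoid(fake)), written with softplus.
    return (_mean([_softplus(-d) for d in dis_real])
            + _mean([_softplus(d) for d in dis_fake]))


class DisLoss:
    """Hypothetical discriminator-loss wrapper mirroring the snippet above."""

    def __init__(self, loss_type="hinge", is_relativistic=False):
        self.loss_type = loss_type
        self.is_relativistic = is_relativistic

    def __call__(self, dis_fake, dis_real=None, **kwargs):
        if self.is_relativistic:
            raise NotImplementedError("relativistic variants are not sketched here")
        if dis_real is None:
            raise ValueError("a discriminator loss needs dis_real")
        if self.loss_type == "hinge":
            return dis_hinge(dis_fake, dis_real)
        elif self.loss_type == "dcgan":
            return dis_dcgan(dis_fake, dis_real)
        raise ValueError(f"unknown loss_type: {self.loss_type}")


print(DisLoss("hinge")([-1.0], [1.0]))  # 0.0
print(DisLoss("hinge")([0.0], [0.0]))   # 2.0
```

Raising on a missing `dis_real` makes the original mistake (silently dropping the real term) fail loudly instead.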
