akanimax / big-discriminator-batch-spoofing-gan
BMSG-GAN with more features
License: MIT License
I ran training for a few hours on 8 GPUs without any progress. Every generated sample is a pixel-wise copy of the others, at every resolution layer.
setup
git clone git@github.com:akanimax/BBMSG-GAN.git
conda create -n bbmsg python==3.7
conda activate bbmsg
conda install pytorch torchvision cudatoolkit=10.0 cudnn scipy==1.2.0 tensorboard -c pytorch
pip install tensorboardX tqdm
training
# calc real fid stats before training
python ../BBMSG-GAN/sourcecode/train.py \
--images_dir="$IMGS" \
--sample_dir="$SAMPLES" \
--model_dir="$MODELS" \
--depth=9 \
--batch_size=24 \
--num_samples=36 \
--feedback_factor=5 \
--checkpoint_factor=1 \
--num_epochs=50000 \
--num_workers=90 \
--log_fid_values=True \
--fid_temp_folder=/tmp/fid_tmp \
--fid_real_stats="$FID" \
--fid_batch_size=64 \
--num_fid_images=5000
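One way to confirm the collapse described above numerically is to measure how much the generated samples differ from one another. The sketch below is a hypothetical, framework-free helper (`mean_pairwise_l1` is not part of BBMSG-GAN); it assumes each sample has already been flattened into a plain list of pixel values. A result of 0.0 means every sample is a pixel-wise copy of every other one.

```python
from itertools import combinations


def mean_pairwise_l1(samples):
    """Mean absolute pixel difference over all pairs of samples.

    `samples` is a list of equally sized, flattened lists of pixel values
    (a hypothetical layout chosen for illustration). 0.0 means all samples
    are identical, i.e. complete mode collapse.
    """
    if len(samples) < 2:
        raise ValueError("need at least two samples")
    total, pairs = 0.0, 0
    for a, b in combinations(samples, 2):
        # per-pair mean absolute difference across all pixels
        total += sum(abs(x - y) for x, y in zip(a, b)) / len(a)
        pairs += 1
    return total / pairs


collapsed = [[0.1, 0.5, 0.9]] * 4
diverse = [[0.1, 0.5, 0.9], [0.9, 0.5, 0.1]]
print(mean_pairwise_l1(collapsed))  # 0.0
print(mean_pairwise_l1(diverse))
```

Tracking this value over training makes it easy to see when the generator's outputs stop varying.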
There appears to be an error in the implementation of LSGAN. The correct implementation is:
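# assumes `th` is `import torch as th` and GANLoss is the repository's loss base class
class LSGAN(GANLoss):
    def __init__(self, dis):
        super().__init__(dis)

    def dis_loss(self, real_samps, fake_samps):
        # least-squares targets: push D(real) towards 1 and D(fake) towards 0
        real_scores = th.mean((self.dis(real_samps) - 1) ** 2)
        fake_scores = th.mean(self.dis(fake_samps) ** 2)
        return 0.5 * (real_scores + fake_scores)

    def gen_loss(self, _, fake_samps):
        # the generator pushes D(fake) towards the "real" target 1
        return 0.5 * th.mean((self.dis(fake_samps) - 1) ** 2)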
See equation 8 of [Least Squares Generative Adversarial Networks].
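A framework-free sketch of the same arithmetic as the LSGAN class, operating on plain lists of discriminator outputs instead of calling `self.dis`; `lsgan_losses` is a hypothetical helper for illustration only. It makes the 0-1 targets visible: the discriminator loss is zero when real scores are 1 and fake scores are 0, and the generator loss is zero when fake scores reach 1.

```python
def lsgan_losses(real_scores, fake_scores):
    """Least-squares GAN losses with 0-1 targets (hypothetical helper).

    dis_loss = 0.5 * (mean((D(x) - 1)^2) + mean(D(G(z))^2))
    gen_loss = 0.5 * mean((D(G(z)) - 1)^2)
    """
    def mean(values):
        return sum(values) / len(values)

    dis_loss = 0.5 * (mean([(s - 1) ** 2 for s in real_scores])
                      + mean([s ** 2 for s in fake_scores]))
    gen_loss = 0.5 * mean([(s - 1) ** 2 for s in fake_scores])
    return dis_loss, gen_loss


# perfect discriminator: real scored 1, fake scored 0
print(lsgan_losses([1.0, 1.0], [0.0, 0.0]))  # (0.0, 0.5)
# generator fully fools the discriminator: fake scored 1
print(lsgan_losses([1.0, 1.0], [1.0, 1.0]))  # (0.5, 0.0)
```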
Hi there, I've been trying to get BBMSG-GAN running, but I can't get training to start. I have two 11 GB 1080 Ti cards on Ubuntu 18.10 with PyTorch 1.0.1 and Python 3.6.
When I run train.py with
python train.py --depth=8 --batch_size=32 --fid_batch_size=32 --spoofing_factor=64 --latent_size=512 --images_dir=datasets/epskal/ --sample_dir=samples/epskal_1 --model_dir=models/epskal_1
it reaches the first epoch and then stops doing anything. I can see that memory has been allocated on my GPUs, but utilization stays at 0%, so no actual work is being done. I left it running overnight to see whether it would start processing, but by the next morning nothing had happened.
When I try to quit, the processes hang and I can't use my GPUs again without restarting the computer. The log is below; I pressed Ctrl+C twice, after which nothing I tried would kill the processes. The same thing happens whether or not the discriminator is set to run in parallel.
Starting the training process ...
Epoch: 1
^CTraceback (most recent call last):
File "train.py", line 310, in <module>
main(parse_arguments())
File "train.py", line 304, in main
fid_batch_size=args.fid_batch_size
File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/GAN.py", line 563, in train
num_accumulations=spoofing_factor)
File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/GAN.py", line 310, in optimize_discriminator
loss = loss_fn.dis_loss(real_batch, fake_samples) / num_accumulations
File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/Losses.py", line 200, in dis_loss
r_preds = self.dis(real_samps)
File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/home/hans/BBMSG-GAN/sourcecode/MSG_GAN/GAN.py", line 197, in forward
y = self.rgb_to_features[self.depth - 2](inputs[self.depth - 1])
File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/hans/.conda/envs/pix/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 75, in parallel_apply
thread.join()
File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1056, in join
self._wait_for_tstate_lock()
File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1072, in _wait_for_tstate_lock
elif lock.acquire(block, timeout):
KeyboardInterrupt
^CException ignored in: <module 'threading' from '/home/hans/.conda/envs/pix/lib/python3.6/threading.py'>
Traceback (most recent call last):
File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1294, in _shutdown
t.join()
File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1056, in join
self._wait_for_tstate_lock()
File "/home/hans/.conda/envs/pix/lib/python3.6/threading.py", line 1072, in _wait_for_tstate_lock
elif lock.acquire(block, timeout):
KeyboardInterrupt
Any idea what I might be doing wrong, or how I can get it working correctly?
P.S. Am I right to assume that the results will be the same as long as batch_size * spoofing_factor = 2048? Without decreasing the batch size considerably I run out of memory, especially at 1024x1024.
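On the P.S.: the traceback shows the discriminator loss being divided by `num_accumulations` (passed in as `spoofing_factor`) before the backward pass, which appears to be ordinary gradient accumulation. For a loss that is a per-sample mean, summing the scaled gradients of equally sized micro-batches reproduces the full-batch gradient, so `--batch_size=32 --spoofing_factor=64` (32 * 64 = 2048) should, at least for the discriminator, behave like one batch of 2048. The toy one-parameter sketch below checks that arithmetic; `grad_mse` and `accumulated_grad` are hypothetical helpers, not repository code. The equivalence is only approximate when samples interact inside a batch, for example through batch-statistics layers such as a minibatch standard-deviation layer, or through losses that compare each score with a batch average, such as relativistic-average losses.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y)^2) over one (micro-)batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, batch_size, num_accumulations):
    """Sum of per-micro-batch gradients, each scaled by 1 / num_accumulations,
    mimicking `loss / num_accumulations` followed by backward() per micro-batch."""
    if len(xs) != batch_size * num_accumulations:
        raise ValueError("data must split into equal micro-batches")
    total = 0.0
    for i in range(num_accumulations):
        chunk = slice(i * batch_size, (i + 1) * batch_size)
        total += grad_mse(w, xs[chunk], ys[chunk]) / num_accumulations
    return total


xs = [0.1 * i for i in range(64)]
ys = [0.5 * x + 1.0 for x in xs]
full = grad_mse(0.3, xs, ys)
for bs, acc in [(32, 2), (16, 4), (8, 8)]:
    # prints True for each split: same effective batch of 64
    print(bs, acc, abs(accumulated_grad(0.3, xs, ys, bs, acc) - full) < 1e-9)
```

Even when the gradients match, runs will not be bit-identical, since data order and random sampling differ between runs.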