yiyang7 / super_resolution_with_cnns_and_gans

Image Super-Resolution Using SRCNN, DRRN, SRGAN, CGAN in Pytorch

Python 19.11% Jupyter Notebook 80.15% MATLAB 0.73% M 0.01%
python pytorch computer-vision superresolution srcnn srresnet drrn srgan cgan super-resolution

super_resolution_with_cnns_and_gans's Introduction

Super Resolution with CNNs and GANs

This is the code for our CS231n project.

Super Resolution with CNNs and GANs,
Yiyang Li, Yilun Xu, Ji Yu

We investigated the problem of image super-resolution (SR), where the goal is to reconstruct high-resolution images from low-resolution inputs. We presented a residual learning framework to ease the training of substantially deep networks; specifically, we reformulated the structure of the deep recursive residual network (DRRN) to improve its performance. To further improve image quality, we built a super-resolution generative adversarial network (SRGAN) framework, in which we proposed several perceptual loss functions, i.e. SSIM loss and/or total variation (TV) loss, to enhance the structural integrity of the generated images. Moreover, a condition is injected into the GAN to mitigate the partial information loss associated with GANs.
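
The recursive residual structure mentioned above follows the DRRN idea: a small residual unit whose weights are shared across several recursions, with an identity skip from the block input. A minimal PyTorch sketch of such a block (channel width and recursion depth are illustrative, not necessarily the values used in this repository):

import torch.nn as nn

class RecursiveBlock(nn.Module):
    # Illustrative DRRN-style recursive block: one pre-activation residual unit
    # applied several times with shared weights, plus an identity skip from the
    # block input. Channel width and recursion count are placeholders.
    def __init__(self, channels=128, num_recursions=9):
        super(RecursiveBlock, self).__init__()
        self.num_recursions = num_recursions
        self.unit = nn.Sequential(
            nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        out = x
        for _ in range(self.num_recursions):
            out = x + self.unit(out)  # weights of self.unit are reused each recursion
        return out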

The results show that our methods can achieve performance equivalent to previous state-of-the-art methods on most of the benchmarks, and outperform them in terms of structural similarity. Example output images are shown in the repository.

If you find this code useful in your project, please star this repository and cite:

@inproceedings{super_resolution_cnns_gans,
  title={Super Resolution with CNNs and GANs},
  author={Li, Yiyang and Xu, Yilun and Yu, Ji},
  year={2018},
}

Installation

This project was implemented in PyTorch 0.4 and Python 3.
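
A minimal environment setup might look like the following; the pinned versions and the extra packages are assumptions, so adjust them to your Python/CUDA setup:

pip install torch==0.4.1 torchvision==0.2.1
pip install numpy pillow tqdm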

Quickstart

  1. Build the dataset: First, download the dataset; we use CelebA. For the CNN-based models we set both the input and the output size to 144×144; for the GAN models we set the input size to 36×36 and the output size to 144×144 (an illustrative preprocessing sketch follows the command below).
python build_dataset.py --data_dir ../img_align_celeba_test --output_dir ../data/cnn_faces --input_size 144 --output_size 144
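
The preprocessing behind this step amounts to resizing each CelebA image to the requested low- and high-resolution sizes. The sketch below only illustrates that idea; build_dataset.py itself presumably also handles the train/dev/test split and the output directory layout:

from PIL import Image

def make_pair(path, input_size, output_size):
    # Illustrative only: produce one (low-res, high-res) training pair.
    # For the GAN models input_size=36 and output_size=144; for the CNN models
    # both are 144, i.e. the degraded image is brought back to the output size.
    img = Image.open(path).convert('RGB')
    hr = img.resize((output_size, output_size), Image.BICUBIC)
    lr = img.resize((input_size, input_size), Image.BICUBIC)
    return lr, hr
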
  2. Train a model

For CNN-based models, e.g. SRCNN (a schematic sketch of the architecture follows the command):

python train_cnn.py --data_dir ../data/cnn_faces --model_dir experiments/srcnn_model --model srcnn --cuda cuda0 --optim adam
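
For reference, SRCNN applies just three convolutions to a bicubically upsampled input: patch extraction, non-linear mapping, and reconstruction. A minimal sketch using the classic 9-1-5 configuration, whose filter counts may differ from the ones used in this repository's model code:

import torch.nn as nn

class SRCNN(nn.Module):
    # Classic 9-1-5 SRCNN; the filter counts follow the original paper and are
    # not necessarily those used in this repository.
    def __init__(self, channels=3):
        super(SRCNN, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, 64, kernel_size=9, padding=4),   # patch extraction
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=1),                    # non-linear mapping
            nn.ReLU(inplace=True),
            nn.Conv2d(32, channels, kernel_size=5, padding=2),   # reconstruction
        )

    def forward(self, x):
        return self.body(x)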

For GAN models, e.g. SRGAN (a sketch of the composite generator loss follows the command):

python train_gan.py --data_dir ../data/gan_faces --model_dir experiments/gan_model --model gan --cuda cuda0 --optim adam
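
The SRGAN generator is trained against a composite objective that combines an adversarial term with image-space and perceptual terms, with the SSIM and/or TV terms described above added on top. A hedged sketch of how such a loss is typically assembled; the weights and the VGG cut-off below are placeholders, not the exact contents of the repository's loss.py:

import torch
import torch.nn as nn
from torchvision.models import vgg16

class GeneratorLoss(nn.Module):
    # Illustrative composite generator loss: pixel MSE + adversarial + VGG perceptual.
    # The repository additionally experiments with SSIM and total-variation terms.
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)  # newer torchvision uses the weights= argument
        self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for p in self.loss_network.parameters():
            p.requires_grad = False
        self.mse = nn.MSELoss()

    def forward(self, fake_out, fake_img, real_img):
        adversarial_loss = torch.mean(1.0 - fake_out)
        image_loss = self.mse(fake_img, real_img)
        perceptual_loss = self.mse(self.loss_network(fake_img),
                                   self.loss_network(real_img))
        return image_loss + 0.001 * adversarial_loss + 0.006 * perceptual_loss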

If you want to resume training from a previous run, add this:

--restore_file "best"
  3. Perform a hyperparameter search, e.g. for SRCNN as a CNN-based model (an illustrative params.json sketch follows the command):
python search_hyperparams.py --data_dir ../data/cnn_faces --parent_dir experiments/learning_rate --model srcnn --model_type cnn
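
Following the CS230 code template this project is built on [1], each experiment directory is expected to contain a params.json with the hyperparameters; the sketch below is only an illustration, and the field names and values are assumptions rather than the repository's exact schema:

import json
import os

# Illustrative hyperparameter file; search_hyperparams.py typically creates one
# sub-directory per value it tries under the given --parent_dir.
params = {"learning_rate": 1e-4, "batch_size": 16, "num_epochs": 100}
os.makedirs("experiments/learning_rate", exist_ok=True)
with open("experiments/learning_rate/params.json", "w") as f:
    json.dump(params, f, indent=4)
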
  4. Display the results of the hyperparameter search in a readable format:
python synthesize_results.py --parent_dir experiments/learning_rate
  5. Evaluate on the test set (a PSNR sketch follows the commands below)

For CNN-based models, e.g. SRCNN:

python evaluate_cnn.py --data_dir ../data/cnn_faces --model_dir experiments/srcnn_model --model srcnn --cuda cuda0

For GAN models, e.g. SRGAN:

python evaluate_gan.py --data_dir ../data/gan_faces --model_dir experiments/gan_model --model gan --cuda cuda0
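
The evaluation scripts report image-quality metrics; PSNR, for example, follows directly from the mean squared error between the super-resolved and ground-truth images. A generic sketch, not necessarily the metric code used in this repository:

import torch

def psnr(sr, hr, max_val=1.0):
    # Peak signal-to-noise ratio in dB, assuming images scaled to [0, max_val].
    mse = torch.mean((sr - hr) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)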

Result List

  1. SRCNN

  2. DRRN

  3. SRGAN

  4. CGAN

Pre-trained Models

The following PyTorch models were trained on the CelebA dataset:

  1. SRCNN (best.pth.tar)
  2. DRRN (best.pth.tar)
  3. SRGAN (best.pth.tar)
  4. CGAN (best.pth.tar)
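
A pre-trained checkpoint such as best.pth.tar can typically be restored with torch.load; the sketch below assumes a CS230-style checkpoint dictionary with a 'state_dict' key, which may differ from the exact format saved by this repository:

import torch

checkpoint = torch.load('experiments/srcnn_model/best.pth.tar', map_location='cpu')
model = SRCNN()  # e.g. the SRCNN sketch shown earlier, or the repository's own model class
model.load_state_dict(checkpoint['state_dict'])  # 'state_dict' key assumed from the CS230 template
model.eval()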

Reference

[1] CS230 Stanford, cs230-code-examples, https://github.com/cs230-stanford/cs230-code-examples, 2018.

[2] tyshiwo, DRRN_CVPR17, https://github.com/tyshiwo/DRRN_CVPR17, 2017.

[3] leftthomas, SRGAN, https://github.com/leftthomas/SRGAN, 2017.

[4] znxlwm, pytorch-generative-model-collections, https://github.com/znxlwm/pytorch-generative-model-collections, 2017.

Issues

RuntimeError: input and target shapes do not match: input [16 x 512 x 8 x 8], target [16 x 512 x 2 x 2]

Hey, I'm trying to load 128×128 images and I have resized them to 64×64 in data_loader, but I still get this error:

Traceback (most recent call last):
File "train_gan.py", line 245, in
args.restore_file,cuda_id=cuda_id)
File "train_gan.py", line 127, in train_and_evaluate
train(netG, netD, optimG, optimD, loss_fn, train_dataloader, metrics, params, cuda_id)
File "train_gan.py", line 66, in train
g_loss = loss_fn(fake_out, fake_img, real_img)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "F:\ProgramTA\CGAN_AinilMardiah\program\loss.py", line 78, in forward
perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\loss.py", line 421, in forward
return F.mse_loss(input, target, reduction=self.reduction)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\functional.py", line 1716, in mse_loss
return _pointwise_loss(lambda a, b: (a - b) ** 2, torch._C._nn.mse_loss, input, target, reduction)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\functional.py", line 1674, in _pointwise_loss
return lambd_optimized(input, target, reduction)
RuntimeError: input and target shapes do not match: input [16 x 512 x 8 x 8], target [16 x 512 x 2 x 2] at c:\programdata\miniconda3\conda-bld\pytorch_1532505617613\work\aten\src\thcunn\generic/MSECriterion.cu:12

Could you help me solve this error?
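
A hedged note on the error above, judging only from the shapes in the traceback: the VGG feature maps of the generated images ([16, 512, 8, 8]) are four times larger than those of the targets ([16, 512, 2, 2]), so fake_img and real_img do not have the same spatial size. The cleanest fix is to rebuild the dataset so that the high-resolution target size matches the generator's output size; as a quick diagnostic, the two can also be matched right before the perceptual loss:

import torch.nn.functional as F

def match_target_size(real_img, fake_img):
    # Diagnostic only: resize the HR targets to the generator's output size so the
    # perceptual (VGG-feature) MSE sees matching shapes. The proper fix is to build
    # the dataset with LR/HR sizes that match the generator's upscaling factor.
    return F.interpolate(real_img, size=fake_img.shape[-2:],
                         mode='bilinear', align_corners=False)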

Why is SRGAN's loss_D 0?

I think that if loss_D is 0, the discriminator can easily distinguish real images from fake images, and this is not good. So why is SRGAN's loss_D 0?

TVLoss question

I found the TV loss formula on Wikipedia (isotropic form): V(y) = Σ_{i,j} √((y[i+1,j] − y[i,j])² + (y[i,j+1] − y[i,j])²),
while your code does not have a sqrt operation.
Could you tell me why?
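
A hedged note on the question above: many SRGAN-style implementations, including the reference implementation in [3], use a squared (smooth) total-variation penalty, which drops the square root of the isotropic definition so that the term stays cheap and everywhere differentiable. A sketch of that variant:

import torch

def tv_loss(x):
    # Squared (smooth) total variation: mean of squared differences between
    # neighbouring pixels along height and width. The isotropic definition on
    # Wikipedia takes a square root of the summed squared differences instead.
    dh = x[:, :, 1:, :] - x[:, :, :-1, :]
    dw = x[:, :, :, 1:] - x[:, :, :, :-1]
    return dh.pow(2).mean() + dw.pow(2).mean()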

RuntimeError: CUDA error: out of memory

Traceback (most recent call last):
File "train_gan.py", line 271, in
train_and_evaluate(netG, netD, train_dl, val_dl, optimG, optimD, loss_fn, metrics, params, opt.model_dir, opt.restore_file,cuda_id=cuda_id)
File "train_gan.py", line 142, in train_and_evaluate
train(netG, netD, optimG, optimD, loss_fn, train_dataloader, metrics, params, cuda_id)
File "train_gan.py", line 75, in train
fake_out = netD(fake_img).mean()
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "F:\ProgramTA\CGAN_AinilMardiah\program\model\cgan.py", line 100, in forward
return torch.sigmoid(self.net(x).view(batch_size))
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\container.py", line 91, in forward
input = module(input)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\modules\activation.py", line 447, in forward
return F.leaky_relu(input, self.negative_slope, self.inplace)
File "C:\Users\PC\Anaconda3\envs\cgan\lib\site-packages\torch\nn\functional.py", line 755, in leaky_relu
return torch._C._nn.leaky_relu(input, negative_slope)
RuntimeError: CUDA error: out of memory

I have tried reducing the batch size but I still get this error. Could you help me resolve it?
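
A hedged note on the error above: besides lowering the batch size, two common causes of GAN out-of-memory errors are building autograd graphs during validation and accumulating losses as graph-holding tensors instead of Python floats (e.g. prefer running_loss += g_loss.item()). A generic sketch of the first mitigation, where netG and val_lr are placeholder names:

import torch

def validate(netG, val_lr):
    # Run the generator without building an autograd graph during evaluation.
    netG.eval()
    with torch.no_grad():
        sr = netG(val_lr)
    netG.train()
    return sr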

License

Under which open-source license are you releasing this software and the trained weights?
I cannot find a LICENSE file in the repository.
Thank you for clearing this up.
