
mahdi-darvish / gan-augmented-pet-classifier


Towards Fine-grained Image Classification with Generative Adversarial Networks and Facial Landmark Detection - Paper Implementation and Supplementary Materials

Home Page: https://arxiv.org/abs/2109.00891

License: MIT License

Jupyter Notebook 61.44% Python 38.56%
generative-adversarial-network vision-transformer facial-landmarks-detection fine-grained-classification stylegan2-ada mobilenetv2 pytorch gan

gan-augmented-pet-classifier's Introduction


GANs Augmented Pet Classifier

Towards Fine-grained Image Classification with Generative Adversarial Networks and Facial Landmark Detection

Mahdi Darvish • Mahsa Pouramini • Hamid Bahador

https://arxiv.org/abs/2109.00891

Abstract: Fine-grained classification remains a challenging task because distinguishing categories needs learning complex and local differences. Diversity in the pose, scale, and position of objects in an image makes the problem even more difficult. Although the recent Vision Transformer models achieve high performance, they need an extensive volume of input data. To encounter this problem, we made the best use of GAN-based data augmentation to generate extra dataset instances. Oxford-IIIT Pets was our dataset of choice for this experiment. It consists of 37 breeds of cats and dogs with variations in scale, poses, and lighting, which intensifies the difficulty of the classification task. Furthermore, we enhanced the performance of the recent Generative Adversarial Network (GAN), StyleGAN2-ADA model to generate more realistic images while preventing overfitting to the training set. We did this by training a customized version of MobileNetV2 to predict animal facial landmarks; then, we cropped images accordingly. Lastly, we combined the synthetic images with the original dataset and compared our proposed method with standard GANs augmentation and no augmentation with different subsets of training data. We validated our work by evaluating the accuracy of fine-grained image classification on the recent Vision Transformer (ViT) Model.

Results

Accuracy of the ViT classifier and FID for three dataset conditions (original, augmented, and cropped-augmented) in data regimes of 10, 50, and 100 percent:

Dataset Variant            10% training data    50% training data    100% training data
                           FID     Accuracy     FID     Accuracy     FID     Accuracy
Original                   -       64.73        -       88.41        -       94.13
Augmented                  71.1    63.32        36.4    88.70        20.7    94.93
Cropped-Augmented (Ours)   49.4    68.55        22.3    91.73        14.1    96.28
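
To make the table easier to compare across rows, the sketch below computes how far each augmented variant moves accuracy relative to the original data, in percentage points. The numbers are copied from the table; the dictionary layout and the function name gain_over_original are illustrative and not part of the repository.

```python
# Accuracy values copied from the results table (keys are the training data regimes in percent).
accuracy = {
    "Original": {10: 64.73, 50: 88.41, 100: 94.13},
    "Augmented": {10: 63.32, 50: 88.70, 100: 94.93},
    "Cropped-Augmented": {10: 68.55, 50: 91.73, 100: 96.28},
}


def gain_over_original(variant):
    """Return the accuracy difference versus the original data, per regime."""
    baseline = accuracy["Original"]
    return {
        regime: round(accuracy[variant][regime] - baseline[regime], 2)
        for regime in baseline
    }


for variant in ("Augmented", "Cropped-Augmented"):
    print(variant, gain_over_original(variant))
```

Read this way, plain GAN augmentation lowers accuracy in the 10% regime, while the cropped variant improves on the original data in all three regimes.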


Comparison between synthetic and authentic images. The figure shows (a) the original data; (b) and (c) cropped and uncropped images generated from the whole dataset, respectively; (d) and (e) cropped and uncropped images generated from the 50% subset; and (f) and (g) cropped and uncropped images generated after training on only 10% of the data. These qualitative visualizations illustrate the effectiveness and interpretability of the method.


The charts show the accuracy of the ViT classifier and the FID for three dataset conditions (original, augmented, and cropped-augmented) in data regimes of 10, 50, and 100 percent:


The evaluated RMSE of the trained MobileNetV2 model with and without landmark normalization:

Model                          Validation RMSE    Test RMSE
MobileNetV2                    3.11               3.56
MobileNetV2 + Normalization    3.02               3.43
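
For readers unfamiliar with the metric, here is a minimal sketch of a root-mean-square error over landmark coordinates. The README does not say how the reported RMSE is aggregated or in which units, so averaging over every x and y coordinate is an assumption, and landmark_rmse is a hypothetical helper rather than a function from the repository.

```python
import math


def landmark_rmse(predicted, target):
    """RMSE over all coordinates of two equally long lists of (x, y) landmarks.

    Assumption: the error is averaged over every x and y value; the
    repository may aggregate differently.
    """
    if len(predicted) != len(target) or not predicted:
        raise ValueError("predicted and target must be non-empty and equally long")
    squared_errors = []
    for (px, py), (tx, ty) in zip(predicted, target):
        squared_errors.append((px - tx) ** 2)
        squared_errors.append((py - ty) ** 2)
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Two landmarks, one of them off by (3, 4): squared errors 9, 16, 0, 0.
print(landmark_rmse([(10, 10), (20, 20)], [(13, 14), (20, 20)]))
```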

Pre-Trained Models

StyleGAN2-ADA trained on cropped pets dataset

Subset   Kimg   FID    Accuracy on ViT   Pre-trained Network
10%      5120   49.4   68.55             subset_10_cropped.pkl
50%      5120   22.3   91.73             subset_50_cropped.pkl
100%     5120   14.1   96.28             subset_100_cropped.pkl

StyleGAN2-ADA trained on the uncropped pets dataset

Subset   Kimg   FID    Accuracy on ViT   Pre-trained Network
10%      5120   71.1   63.32             subset_10_original.pkl
50%      5120   36.4   88.70             subset_50_original.pkl
100%     5120   20.7   94.93             subset_100_original.pkl
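
The README does not state how the 10% and 50% subsets were drawn. As one plausible approach, the sketch below samples the same fraction from every class with a fixed seed, so that each breed stays represented. The function name stratified_subset, the seed, and the per-class rounding are all assumptions made for illustration.

```python
import random


def stratified_subset(files_by_class, fraction, seed=0):
    """Keep roughly `fraction` of the files of every class (at least one per class).

    files_by_class maps a class name to a list of file names.
    """
    if not 0 < fraction <= 1:
        raise ValueError("fraction must be in (0, 1]")
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    subset = {}
    for label in sorted(files_by_class):
        files = sorted(files_by_class[label])
        keep = max(1, round(len(files) * fraction))
        subset[label] = sorted(rng.sample(files, keep))
    return subset


example = {
    "class1": [f"img{i}.jpg" for i in range(20)],
    "class2": [f"img{i}.jpg" for i in range(10)],
}
print({label: len(files) for label, files in stratified_subset(example, 0.5).items()})
```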

Getting started

Dataset

Oxford-IIIT Pets Variants:

The official dataset can be reached from:

Landmark prediction datasets:

We combined the following datasets for predicting landmarks:

StyleGAN2-ADA Installation

Running Locally

Running StyleGAN2-ADA locally requires:

  • 64-bit Python 3.7.

  • PyTorch and dependencies:

      !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
    
  • Python libraries:

      !pip install ninja tqdm ffmpeg
    

Data Preparation

Please prepare the data as follows:

└── root
    ├── gans_training
    │   ├── dataset
    │   │   └── oxford_pets
    │   │       ├── class1
    │   │       │   ├── img1.jpg
    │   │       │   ├── img2.jpg
    │   │       │   └── ...
    │   │       ├── class2
    │   │       │   ├── img1.jpg
    │   │       │   ├── img2.jpg
    │   │       │   └── ...
    │   │       └── class3
    │   │           ├── img1.jpg
    │   │           ├── img2.jpg
    │   │           └── ...
    │   │  
    │   ├── images
    │   │   
    │   └── experiments
    └── stylegan2-ada-pytorch
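
Before converting or training, it can help to confirm that the class-per-folder layout above is in place. The following sketch counts the image files in each class folder. The helper name count_images_per_class and the accepted file suffixes are illustrative assumptions, not part of the repository.

```python
from pathlib import Path

# Assumption: images use one of these extensions.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(dataset_dir):
    """Return {class_folder_name: number_of_image_files} for a class-per-folder dataset."""
    root = Path(dataset_dir)
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


# Example call, using the layout shown above:
# count_images_per_class("root/gans_training/dataset/oxford_pets")
```

An empty or missing class folder shows up as a count of zero or as a missing key, which is usually easier to spot here than in a failed training run.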

Run the following command to convert the images into the dataset format that StyleGAN2-ADA training expects (pre-converted data for all variants of Oxford-IIIT Pets is available in the Dataset section):

!python root/stylegan2-ada-pytorch/dataset_tool.py --source root/gans_training/images --dest root/gans_training/dataset

Training

Note: You can also modify and extend the config parameters; see the official StyleGAN2-ADA repo for the available options.

New Network

!python root/stylegan2-ada-pytorch/train.py --snap 30 --outdir root/gans_training/experiments --data root/gans_training/dataset

Resume Training

Resuming takes an extra argument: pass the last .pkl file generated by the network.

!python root/stylegan2-ada-pytorch/train.py --snap 30 --resume *insert path to last .pkl file* --outdir root/gans_training/experiments --data root/gans_training/dataset
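
Finding the last .pkl file by hand gets tedious after several runs. The sketch below searches an output directory for the newest snapshot. It assumes snapshots are named network-snapshot-<kimg>.pkl and live in run folders whose names start with an increasing number, as the official StyleGAN2-ADA PyTorch training script is understood to produce; latest_snapshot is a hypothetical helper, so check the names in your own output directory.

```python
import re
from pathlib import Path

# Assumed snapshot naming: network-snapshot-000200.pkl
SNAPSHOT_RE = re.compile(r"network-snapshot-(\d+)\.pkl$")


def latest_snapshot(outdir):
    """Return the path of the newest snapshot under outdir, or None if there is none.

    Snapshots are ordered by run folder name first, then by the kimg number.
    """
    best_key, best_path = None, None
    for path in Path(outdir).rglob("network-snapshot-*.pkl"):
        match = SNAPSHOT_RE.search(path.name)
        if match is None:
            continue
        key = (path.parent.name, int(match.group(1)))
        if best_key is None or key > best_key:
            best_key, best_path = key, path
    return best_path


# Example call:
# latest_snapshot("root/gans_training/experiments")
```

The returned path can then be passed as the value of --resume in the command above.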

Generating Images with Pre-trained Models

To generate synthetic images, run the command below with your own .pkl file and set the seeds to a range of numbers according to your needs.

!python root/stylegan2-ada-pytorch/generate.py \
--network=*path to .pkl file* --outdir=root/results --seeds=50-70
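
Each seed yields one generated image, so the seed specification determines how many images are written. The sketch below shows how a value such as 50-70 can be expanded into individual seeds. It treats the range as inclusive, which matches the behavior the official generate.py is understood to have, but parse_seeds is an illustrative reimplementation and not the script's own function.

```python
import re


def parse_seeds(spec):
    """Expand 'a-b' into the inclusive list a..b, or 'a,b,c' into [a, b, c]."""
    match = re.fullmatch(r"(\d+)-(\d+)", spec)
    if match:
        start, end = int(match.group(1)), int(match.group(2))
        return list(range(start, end + 1))
    return [int(value) for value in spec.split(",")]


seeds = parse_seeds("50-70")
print(len(seeds), seeds[0], seeds[-1])
```

Under this reading, --seeds=50-70 produces 21 images.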

Running through Google Colab

We provide notebooks for running the GAN models directly in Google Colab:

Landmark Detection

Requirements

  • 64-bit Python 3.7

  • Python libraries:

     !pip install Pillow==5.4.1 tqdm keras opencv-python
    

Training

To train the model on a custom dataset, change the dataset directory in the train.py script; the default is the parent directory of the pet dataset. The annotations should use the same format as the Cat-dataset.

!python train.py --loss_fn iou --flip_horizontal --crop_scale_balanced --rotate_n 20 --ReduceLROnPlateau_factor 0.8 --ReduceLROnPlateau_patience 10
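
The two ReduceLROnPlateau flags in the command control how the learning rate decays when the monitored loss stops improving. The sketch below is a simplified, framework-free simulation of that rule: after `patience` consecutive epochs without improvement, the rate is multiplied by `factor`. The starting rate of 0.001 is a made-up value, and details such as cooldown, minimum delta, and a lower bound on the rate are deliberately left out, so this is not the callback the training script actually uses.

```python
def simulate_reduce_on_plateau(losses, initial_lr, factor=0.8, patience=10):
    """Return the learning rate in effect after each epoch of `losses`."""
    lr = initial_lr
    best = float("inf")
    wait = 0  # epochs since the last improvement
    history = []
    for loss in losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor
                wait = 0
        history.append(lr)
    return history


# One improving epoch followed by ten flat epochs triggers a single reduction.
schedule = simulate_reduce_on_plateau([1.0] + [1.0] * 10, initial_lr=0.001)
print(schedule[-2], schedule[-1])
```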

Trained models will be placed in: landmark_detection/tools/models

Prediction

Run the following command with the image path. The output image is saved at the same path:

!python final_predictor.py --img_path *insert image path here* --landmarks_model_path=*inside tools/models as default*
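
The abstract describes cropping images according to the predicted facial landmarks. How the repository derives the crop is not documented here, so the sketch below shows one simple possibility: take the bounding box of the landmarks, pad it by a relative margin, and clamp it to the image. The function name crop_box_from_landmarks and the 25% margin are assumptions.

```python
import math


def crop_box_from_landmarks(landmarks, image_size, margin=0.25):
    """Return (left, upper, right, lower) around the landmarks, clamped to the image.

    landmarks is a list of (x, y) points; image_size is (width, height).
    margin is the padding added on each side, relative to the box size.
    """
    if not landmarks:
        raise ValueError("at least one landmark is required")
    width, height = image_size
    xs = [x for x, _ in landmarks]
    ys = [y for _, y in landmarks]
    pad_x = (max(xs) - min(xs)) * margin
    pad_y = (max(ys) - min(ys)) * margin
    left = max(0, math.floor(min(xs) - pad_x))
    upper = max(0, math.floor(min(ys) - pad_y))
    right = min(width, math.ceil(max(xs) + pad_x))
    lower = min(height, math.ceil(max(ys) + pad_y))
    return (left, upper, right, lower)


box = crop_box_from_landmarks([(40, 50), (80, 50), (60, 90)], (128, 128))
print(box)
```

The tuple uses the same (left, upper, right, lower) order as Pillow's Image.crop, so it can be passed to that method directly.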

Citation

@misc{darvish2021finegrained,
  title={Towards Fine-grained Image Classification with Generative Adversarial Networks and Facial Landmark Detection}, 
  author={Mahdi Darvish and Mahsa Pouramini and Hamid Bahador},
  year={2021},
  eprint={2109.00891},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
  }


gan-augmented-pet-classifier's Issues

upfirdn2d_plugin error

Hello,
I encountered the following error while training on your cropped dataset:
I checked the StyleGAN2 repo and couldn't find a similar issue.
Any suggestions?

`Training duration: 25000 kimg
Number of GPUs: 1
Number of images: 48255
Image resolution: 128
Conditional model: False
Dataset x-flips: False

Creating output directory...
Launching processes...
Loading training set...

Num images: 48255
Image shape: [3, 128, 128]
Label shape: [0]

Constructing networks...
Setting up PyTorch plugin "bias_act_plugin"... Failed!
C:\Users\Graham\Desktop\ICSCapstone\stylegan2-ada-pytorch-main\torch_utils\ops\bias_act.py:50: UserWarning: Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:

Traceback (most recent call last):
File "C:\Users\Graham\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\cpp_extension.py", line 1673, in _run_ninja_build
env=env)
File "C:\Users\Graham\AppData\Local\Programs\Python\Python37\lib\subprocess.py", line 468, in run
output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "C:\Users\Graham\Desktop\ICSCapstone\stylegan2-ada-pytorch-main\torch_utils\ops\upfirdn2d.py", line 32, in _init
_plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
File "C:\Users\Graham\Desktop\ICSCapstone\stylegan2-ada-pytorch-main\torch_utils\custom_ops.py", line 110, in get_plugin
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
File "C:\Users\Graham\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\cpp_extension.py", line 1091, in load
keep_intermediates=keep_intermediates)
File "C:\Users\Graham\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\cpp_extension.py", line 1317, in _jit_compile
return _import_module_from_library(name, build_directory, is_python_module)
File "C:\Users\Graham\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\cpp_extension.py", line 1699, in _import_module_from_library
file, path, description = imp.find_module(module_name, [path])
File "C:\Users\Graham\AppData\Local\Programs\Python\Python37\lib\imp.py", line 297, in find_module
raise ImportError(_ERR_MSG.format(name), name=name)
ImportError: No module named 'upfirdn2d_plugin'`
