
GAN domain translation for recognition

arXiv | Paper | Video (5m) | Poster


Codebase for the publication:

Dark Side Augmentation: Generating Diverse Night Examples for Metric Learning [arXiv]. Albert Mohwald, Tomas Jenicek and Ondřej Chum. In International Conference on Computer Vision (ICCV), 2023.

This repository builds on top of image retrieval implemented in mdir and cirtorch and adapts CycleGAN and CUT for image-to-image translation.

(Figure: train_and_finetune pipeline overview)


Pretrained models

Downloads

Day-to-night generator models

(Examples: original day image | CycleGAN night | HEDNGAN night)

More examples

Model | Weights
CycleGAN (day-to-night) | Download
HEDNGAN (day-to-night) | Download

Embedding models

Model | Avg | Tokyo | ROxf | RPar | Weights | Whitening
GeM VGG16 CycleGAN | 74.0 | 90.2 | 60.7 | 71.0 | Download | Download
GeM VGG16 HEDNGAN | 73.5 | 88.8 | 61.1 | 70.7 | Download | Download
GeM ResNet-101 CycleGAN | 78.4 | 92.0 | 66.8 | 76.4 | Download | Download
GeM ResNet-101 HEDNGAN | 78.4 | 91.7 | 66.6 | 76.8 | Download | Download

All models are pretrained on Retrieval-SfM 120k.

Torch Hub

To use any pretrained model, first install PyTorch by following the official installation instructions.

import torch

# Day-to-night generators
cyclegan = torch.hub.load('mohwald/gandtr', 'cyclegan', pretrained=True)
hedngan = torch.hub.load('mohwald/gandtr', 'hedngan', pretrained=True)

# Image descriptors
gem_vgg16_cyclegan = torch.hub.load('mohwald/gandtr', 'gem_vgg16_cyclegan', pretrained=True)
gem_vgg16_hedngan = torch.hub.load('mohwald/gandtr', 'gem_vgg16_hedngan', pretrained=True)
gem_resnet101_cyclegan = torch.hub.load('mohwald/gandtr', 'gem_resnet101_cyclegan', pretrained=True)
gem_resnet101_hedngan = torch.hub.load('mohwald/gandtr', 'gem_resnet101_hedngan', pretrained=True)

Models initialized this way are pretrained and loaded on the GPU by default. If you do not want to load pretrained weights, pass pretrained=False; to load the model on e.g. the CPU, pass device="cpu".
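
For example, to skip loading the pretrained weights, or to keep a model on the CPU (per the options above):

cyclegan_untrained = torch.hub.load('mohwald/gandtr', 'cyclegan', pretrained=False)
hedngan_cpu = torch.hub.load('mohwald/gandtr', 'hedngan', pretrained=True, device="cpu")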

Important

The expected input of all descriptor models listed above is a batch of normalized images after a CLAHE transform. The recommended way to obtain the image preprocessing transforms (suitable for a dataset loader) is demonstrated in the snippet below:

>>> import torch
>>> model = torch.hub.load('mohwald/gandtr', 'gem_vgg16_hedngan')
>>> model.transform
Compose(
    Pil2Numpy()
    ApplyClahe(clip_limit=1.0, grid_size=8, colorspace=lab)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], strict_shape=True)
)
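
Since model.transform consumes a PIL image, it can be passed directly to a standard dataset. A minimal sketch, assuming a hypothetical images/ directory laid out for torchvision's ImageFolder (one subdirectory per label):

import torch
from torchvision import datasets

model = torch.hub.load('mohwald/gandtr', 'gem_vgg16_hedngan')
# ImageFolder yields PIL images; model.transform turns them into normalized tensors
dataset = datasets.ImageFolder("images/", transform=model.transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)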

Inference

A single global descriptor can be extracted from an image simply with:

import torch
from PIL import Image

model = torch.hub.load('mohwald/gandtr', 'gem_vgg16_hedngan')

# Apply the model's own preprocessing (CLAHE + normalization) and add a batch dimension
with open("orloj.jpg", 'rb') as f:
    image = Image.open(f).convert("RGB")
inputs = model.transform(image).unsqueeze(0)

# Extract the global descriptor
with torch.no_grad():
    vec = model(inputs)
print(vec)

The output is a 512-dimensional L2-normalized whitened vector. For orloj.jpg, you should obtain a vector ending very close to:

        -6.3813e-03, -2.2138e-04,  2.0179e-03,  1.9477e-02,  6.6316e-03,
         1.0677e-02,  1.0847e-02], device='cuda:0')
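
Since the descriptors are L2-normalized, the similarity of two images is just the dot product of their vectors. A minimal retrieval sketch (the database image paths are hypothetical):

import torch
from PIL import Image

model = torch.hub.load('mohwald/gandtr', 'gem_vgg16_hedngan')

def describe(path):
    # One L2-normalized global descriptor per image
    image = Image.open(path).convert("RGB")
    with torch.no_grad():
        return model(model.transform(image).unsqueeze(0))[0]

query = describe("orloj.jpg")
database = torch.stack([describe(p) for p in ("castle.jpg", "bridge.jpg")])
scores = database @ query                 # cosine similarity of unit-length vectors
print(scores.argsort(descending=True))    # database indices ranked by similarity to the query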

Installation

  1. Install ffmpeg and graphviz, if you do not have them; ffmpeg is required for OpenCV, and graphviz is used to draw the network architecture.
sudo apt-get install ffmpeg libsm6 libxext6 graphviz   # Ubuntu and Debian-based systems
sudo dnf install ffmpeg libSM graphviz                 # RHEL, Fedora, and CentOS-based systems
  2. Clone this repository: git clone git@github.com:mohwald/gandtr.git && cd gandtr
  3. Install dependencies: pip install -r requirements.txt
  4. (Optional) Set environment variables:
    • ${CIRTORCH_ROOT}, where you want to store image data and model weights. All necessary data for evaluation and training are downloaded there automatically.
    • ${CUDA_VISIBLE_DEVICES}, set to a single GPU, e.g. export CUDA_VISIBLE_DEVICES=0
  5. Go to mdir/examples
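
To sanity-check the environment before running any scenario, a minimal snippet (assuming OpenCV was installed via the requirements):

import cv2     # OpenCV should import cleanly once ffmpeg and its system libraries are present
import torch

print(torch.cuda.is_available())   # True if ${CUDA_VISIBLE_DEVICES} exposes a usable GPU
print(torch.cuda.device_count())   # 1 when pinned to a single GPU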

General scenario format and execution

Inside mdir/examples, each experiment is executed with the script perform_scenario.py, which runs YAML scenarios with this structure:

TARGET:
  1_step:  # first step parameters dictionary
      ...
  2_step:  # second step parameters dictionary
      ...
  ...

Nested dictionary keys can be used in parameters and variables (nested keys are separated by a dot). Bash-style variables are supported within a TARGET, e.g. ${1_step.section.key}. A special variable ${SCENARIO_NAME} denotes the name of the executed scenario (the last scenario name, if scenarios are overlaid).

A scenario is executed with perform_scenario.py as:

python3 perform_scenario.py TARGET SCENARIO_NAME_1.yml [SCENARIO_NAME_2.yml]...

Scenarios can be overlaid, so that all variables of SCENARIO_NAME_1 are replaced by variables from SCENARIO_NAME_2.
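
For example, to run the eval target with variables overridden by a second, hypothetical overlay file:

python3 perform_scenario.py eval iccv23/eval/hedngan.yml my_overrides.yml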

Evaluation

All scenarios for evaluation are located inside iccv23/eval.

To evaluate a model from the ICCV 2023 paper, e.g. the HED-N-GAN method with a GeM VGG16 backbone, run:

python3 perform_scenario.py eval iccv23/eval/hedngan.yml

Warning

The Oxford and Paris Buildings dataset images are no longer available at their original sources and thus cannot be downloaded automatically. One option is to download the images from Kaggle (requires registration). Images should be placed inside ${CIRTORCH_ROOT}/data/test/{oxford5k, paris6k}/jpg, without any nested directories.

To change the GAN generator used in the augmentation, use a different scenario with the corresponding generator name. To change the embedding backbone, change eval to eval_r101 to evaluate GeM ResNet-101. With these options, you should get the following results:

VGG-16 Backbone (eval)

Model | Tokyo | ROxf | RPar
hedngan | 88.8 | 61.1 | 70.7
cyclegan | 90.2 | 60.7 | 71.0

ResNet-101 Backbone (eval_r101)

Model | Tokyo | ROxf | RPar
hedngan | 91.7 | 66.6 | 76.8
cyclegan | 92.0 | 66.8 | 76.4

Training

All scenarios for training from scratch are located inside iccv23/train.

GAN generator training

To train a GAN generator from scratch, e.g. HED-N-GAN, run:

python3 perform_scenario.py train iccv23/train/hedngan.yml

To change the GAN model, replace the YAML scenario with the one corresponding to the model name, e.g. hedngan.yml with cyclegan.yml.

(Optional) After the generator training is finished, the trained generator can translate arbitrary images: pass a list of image paths on standard input and execute the output target:

python3 perform_scenario.py output iccv23/train/hedngan.yml
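
For example, to translate a single image (the path is hypothetical):

echo "data/day/orloj.jpg" | python3 perform_scenario.py output iccv23/train/hedngan.yml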

Metric learning

To finetune an embedding network for image retrieval using augmentation with the HED-N-GAN generator, run:

python3 perform_scenario.py finetune iccv23/train/hedngan.yml

This command both finetunes the embedding model and subsequently evaluates it.

To change the backbone used for the finetuning, replace finetune with finetune_r101 for GeM ResNet-101.

gandtr's Issues

Pre-trained models

Hello there, I was wondering where you got the pre-trained HED model from your config /mnt/fry2/landmarkdb/models/pytorch/weights/hed/sniklaus_github.pth. The one I get from https://github.com/sniklaus/pytorch-hed has different model layer names, so I can't load it directly with this code. Did you train the model yourself, or did you simply rename the layers of that one?

DSA’s TorchHub

Hello, Mohwald.
I want to conduct some comparative experiments with DSA, which is very important to me.
I tried your sample code to reproduce it; however, I ran into some difficulties.
I'm going to add your method to this project for testing: https://github.com/gmberton/VPR-methods-evaluation
Many famous methods have already been added to the models folder of that project.

import torch
import torchvision.transforms as tfm

from models import utils


class DSAModel(torch.nn.Module):
    def __init__(self, device='cuda'):
        super().__init__()        
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        
        self.net = torch.hub.load('mohwald/gandtr', 'gem_vgg16_hedngan').to(self.device)
        self.state_dict = torch.load("/home/ubuntu/.cache/torch/hub/checkpoints/hedngan_embed_vgg16.pth")
        self.net.model.load_state_dict(self.state_dict['model_state'])

        self.un_normalize = utils.UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        self.normalize = tfm.Normalize(mean=[0.48501960784313836, 0.4579568627450961, 0.4076039215686255],
                                       std=[0.00392156862745098, 0.00392156862745098, 0.00392156862745098])

    def forward(self, images):
        images = self.normalize(self.un_normalize(images))
        descriptors = self.net(images)
        return descriptors
  1. I cannot load the model onto the GPU.
  2. Are the normalization parameters I use fair for DSA?
  3. Do you have any other comments about this code?

Thank you.
