
ucdvision / cgc


Official PyTorch code for the CVPR 2022 paper - Consistent Explanations by Contrastive Learning

License: MIT License

Python 99.77% Makefile 0.21% Shell 0.02%
cgc deep-learning explainable-ai explainable-models contrastive-learning self-supervised-learning

cgc's Introduction

Consistent-Explanations-by-Contrastive-Learning

Official PyTorch code for CVPR 2022 paper - Consistent Explanations by Contrastive Learning

Post-hoc explanation methods, e.g., Grad-CAM, enable humans to inspect the spatial regions responsible for a particular network decision. However, it has been shown that such explanations are not always consistent with human priors, such as consistency across image transformations. Given an interpretation algorithm, e.g., Grad-CAM, we introduce a novel training method that trains the model to produce more consistent explanations. Since obtaining the ground truth for a desired model interpretation is not a well-defined task, we adopt ideas from contrastive self-supervised learning and apply them to the interpretations of the model rather than to its embeddings. We show that our method, Contrastive Grad-CAM Consistency (CGC), results in Grad-CAM interpretation heatmaps that are more consistent with human annotations while still achieving comparable classification accuracy. Moreover, our method acts as a regularizer and improves accuracy in limited-data, fine-grained classification settings. In addition, because our method does not rely on annotations, it allows unlabeled data to be incorporated into training, which enables better generalization of the model.
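To make the idea concrete, here is a minimal NumPy sketch of the two ingredients, written under stated assumptions rather than taken from this repository. `grad_cam` computes a Grad-CAM map from one layer's activations and gradients (ReLU of the gradient-weighted channel sum). `cgc_loss` is an InfoNCE-style loss in which the positive for each augmented image's heatmap is the identically transformed heatmap of its original image, and the other images in the batch act as negatives; only horizontal flips are modeled. We assume that `temperature` corresponds to the `-t` flag and that `--lambda` weights the CGC term against cross-entropy, as in the hypothetical `total_loss`. The real implementation is in `train_eval_cgc.py` and may differ, for example in how crops are handled or where negatives come from.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM map from (K, H, W) activations A^k and gradients dy_c/dA^k."""
    weights = gradients.mean(axis=(1, 2))             # alpha_k: spatially pooled gradients
    cam = np.tensordot(weights, activations, axes=1)  # sum_k alpha_k * A^k -> (H, W)
    return np.maximum(cam, 0.0)                       # keep positive evidence only


def cgc_loss(cams_aug, cams_orig, flips, temperature=0.5):
    """InfoNCE-style consistency loss over a batch of heatmaps (sketch).

    cams_aug:  (N, H, W) heatmaps of the augmented images.
    cams_orig: (N, H, W) heatmaps of the original images.
    flips:     (N,) booleans, True where the augmentation was a horizontal flip.
    """
    flips = np.asarray(flips, dtype=bool)
    # Apply each sample's spatial transform to the heatmap of its original image.
    target = np.where(flips[:, None, None], cams_orig[:, :, ::-1], cams_orig)
    # Flatten and L2-normalize so that dot products are cosine similarities.
    q = cams_aug.reshape(len(cams_aug), -1)
    k = target.reshape(len(target), -1)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    k = k / (np.linalg.norm(k, axis=1, keepdims=True) + 1e-12)
    logits = q @ k.T / temperature  # (N, N); positive pairs sit on the diagonal
    # Numerically stable log-softmax per row, then cross-entropy on the diagonal.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_prob)))


def total_loss(ce_loss, cgc, lam=0.5):
    """Hypothetical combined objective: cross-entropy plus the lambda-weighted CGC term."""
    return ce_loss + lam * cgc
```

Normalizing the flattened maps makes the loss depend on where the heatmap puts its mass rather than on its overall scale; whether the repository normalizes the same way is an assumption of this sketch.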

Teaser image


Bibtex

@InProceedings{Pillai_2022_CVPR,
author = {Pillai, Vipin and Abbasi Koohpayegani, Soroush and Ouligian, Ashley and Fong, Dennis and Pirsiavash, Hamed},
title = {Consistent Explanations by Contrastive Learning},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022}
}

Pre-requisites

  • PyTorch 1.3: please install PyTorch and CUDA if they are not already installed.

Datasets

Training

Train and evaluate a ResNet50 model on the ImageNet dataset using our CGC loss

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_eval_cgc.py /datasets/imagenet -a resnet50 -p 100 -j 8 -b 256 --lr 0.1 --lambda 0.5 -t 0.5 --save_dir <SAVE_DIR> --log_dir <LOG_DIR>

Train and evaluate a ResNet50 model on a 1% labeled subset of ImageNet, using the rest of the dataset as unlabeled data. The model is initialized from a SwAV pre-trained checkpoint.

For the command below, <PATH_TO_SWAV_MODEL_PRETRAINED> can be downloaded from the SwAV GitHub repository (https://github.com/facebookresearch/swav). We use the checkpoint listed in the first row of that repository's Model Zoo (800 epochs, 75.3% ImageNet top-1 accuracy).

CUDA_VISIBLE_DEVICES=0,1 python train_imagenet_1pc_swav_cgc_unlabeled.py <PATH_TO_1%_IMAGENET> -a resnet50 -b 128 -j 8 --lambda 0.25 -t 0.5 --epochs 50 --lr 0.02 --lr_last_layer 5 --resume <PATH_TO_SWAV_MODEL_PRETRAINED> --save_dir <SAVE_DIR> --log_dir <LOG_DIR> 2>&1 | tee <PATH_TO_CMD_LOG_FILE>
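This README does not describe how the 1% labeled subset behind <PATH_TO_1%_IMAGENET> was built. Purely to illustrate the idea (keep about 1% of each class labeled and treat every other training image as unlabeled), a hypothetical helper `split_labeled_unlabeled` could look like the sketch below. It draws a random per-class split; for comparable numbers, use the same split the authors used rather than this one.

```python
import random
from collections import defaultdict


def split_labeled_unlabeled(samples, fraction=0.01, seed=0):
    """Split (path, class_idx) pairs into a per-class labeled subset and an unlabeled pool.

    Keeps round(fraction * n_c) images of every class c labeled (at least one),
    so each class is represented; the remaining images keep only their paths.
    """
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append(path)

    rng = random.Random(seed)  # fixed seed -> reproducible split
    labeled, unlabeled = [], []
    for label in sorted(by_class):
        paths = sorted(by_class[label])  # sort first so the result ignores input order
        rng.shuffle(paths)
        k = max(1, round(fraction * len(paths)))
        labeled.extend((p, label) for p in paths[:k])
        unlabeled.extend(paths[k:])  # labels are dropped for the unlabeled pool
    return labeled, unlabeled
```

For example, 3 classes with 200 images each yield 6 labeled and 594 unlabeled images with the default `fraction=0.01`.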

Checkpoints

  • ResNet50 model pre-trained on ImageNet - link

Evaluation

Evaluate model checkpoint using Content Heatmap (CH) evaluation metric

We use the evaluation code adapted from the TorchRay framework.
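Judging by the script names (`*_gradcam_energy_inside_bbox.py`), the Content Heatmap score measures how much of the Grad-CAM heatmap's energy falls inside the ground-truth bounding boxes. Below is a minimal sketch of that quantity, assuming a non-negative heatmap already resized to the image resolution and pixel boxes with exclusive max coordinates. `energy_inside_bbox` is a hypothetical name, and the TorchRay-based scripts may resize or normalize differently.

```python
import numpy as np


def energy_inside_bbox(heatmap, boxes):
    """Share of a heatmap's total (non-negative) energy that lies inside the boxes.

    heatmap: (H, W) array, already resized to the image resolution.
    boxes:   iterable of (xmin, ymin, xmax, ymax) pixel boxes; xmax/ymax exclusive.
    """
    heatmap = np.maximum(np.asarray(heatmap, dtype=float), 0.0)  # clip negatives
    total = heatmap.sum()
    if total == 0:
        return 0.0  # an all-zero map places no energy anywhere
    mask = np.zeros(heatmap.shape, dtype=bool)
    for xmin, ymin, xmax, ymax in boxes:
        mask[ymin:ymax, xmin:xmax] = True  # union of boxes, overlaps counted once
    return float(heatmap[mask].sum() / total)
```

Annotation files may store inclusive max coordinates, in which case add one to xmax and ymax before calling the function.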

  • Change directory to TorchRay and install the library. Please refer to the TorchRay repository for full documentation and instructions.

    • cd TorchRay
    • python setup.py install
  • Change directory to TorchRay/torchray/benchmark

    • cd torchray/benchmark
  • For the ImageNet & CUB-200 datasets, this evaluation requires the following structure for validation images and bounding box XML annotations:

    • <PATH_TO_FLAT_VAL_IMAGES_BBOX>/val/*.JPEG - Flat list of validation images
    • <PATH_TO_FLAT_VAL_IMAGES_BBOX>/annotation/*.xml - Flat list of annotation xml files
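Before running the evaluation, it can help to check that every image in the flat val/ folder has a matching annotation. The sketch below assumes PASCAL VOC-style XML files (`object` elements with a `name` child and a `bndbox` child holding `xmin`, `ymin`, `xmax`, `ymax`), which is the layout of the ImageNet bounding-box annotations. `read_boxes` and `pair_images_with_annotations` are hypothetical helpers, not part of TorchRay.

```python
import xml.etree.ElementTree as ET
from pathlib import Path


def read_boxes(xml_text):
    """Return [(name, xmin, ymin, xmax, ymax), ...] from a VOC-style annotation string."""
    root = ET.fromstring(xml_text)
    boxes = []
    for obj in root.iter("object"):
        bb = obj.find("bndbox")
        coords = [int(float(bb.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")]
        boxes.append((obj.find("name").text, *coords))
    return boxes


def pair_images_with_annotations(root_dir):
    """Match <root>/val/<stem>.JPEG with <root>/annotation/<stem>.xml by file stem."""
    root = Path(root_dir)
    pairs = []
    for img in sorted((root / "val").glob("*.JPEG")):
        xml_path = root / "annotation" / (img.stem + ".xml")
        if xml_path.exists():  # images without an annotation are skipped
            pairs.append((img, xml_path))
    return pairs
```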

Evaluate ResNet50 models trained on the full ImageNet dataset:

CUDA_VISIBLE_DEVICES=0 python evaluate_imagenet_gradcam_energy_inside_bbox.py <PATH_TO_FLAT_VAL_IMAGES_BBOX> -j 0 -b 1 --resume <PATH_TO_SAVED_CHECKPOINT_FILE> --input_resize 448 -a resnet50

Evaluate ResNet50 models trained on the CUB-200 fine-grained dataset:

CUDA_VISIBLE_DEVICES=0 python evaluate_finegrained_gradcam_energy_inside_bbox.py <PATH_TO_FLAT_VAL_IMAGES_BBOX> --dataset cub -j 0 -b 1 --resume <PATH_TO_SAVED_CHECKPOINT_FILE> --input_resize 448 -a resnet50

Evaluate SwAV-initialized ResNet50 models trained on a 1% labeled subset of ImageNet, with the remaining images used as unlabeled data:

CUDA_VISIBLE_DEVICES=0 python evaluate_swav_imagenet_gradcam_energy_inside_bbox.py <PATH_TO_IMAGENET_VAL_FLAT> -j 0 -b 1 --resume <PATH_TO_SAVED_CHECKPOINT_FILE> --input_resize 448 -a resnet50

Evaluate model checkpoint using Insertion AUC (IAUC) evaluation metric

Change to the RISE/ directory and run the following commands:

Evaluate the pre-trained ResNet50 model:

CUDA_VISIBLE_DEVICES=0 python evaluate_auc_metrics.py --pretrained

Evaluate a ResNet50 model trained using our CGC method:

CUDA_VISIBLE_DEVICES=0 python evaluate_auc_metrics.py --ckpt-path <PATH_TO_SAVED_CHECKPOINT_FILE>
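The insertion metric from RISE starts from an image with the evidence removed, adds pixels back in decreasing order of saliency, and records the model's probability for the target class after each step; the score is the area under that curve, so a better heatmap drives the probability up sooner. A minimal NumPy sketch follows, assuming an all-zero starting canvas and a user-supplied `predict` callable (both assumptions). RISE's own evaluation typically starts from a blurred copy of the image and batches the forward passes, so `evaluate_auc_metrics.py` may differ in these details.

```python
import numpy as np


def insertion_auc(image, saliency, predict, step=1, baseline=None):
    """Insertion AUC for one image (sketch).

    image:    (H, W, C) array.
    saliency: (H, W) array; larger values mean more important pixels.
    predict:  callable mapping an (H, W, C) array to the target-class probability.
    step:     number of pixels inserted per step.
    baseline: starting canvas; defaults to all zeros (RISE typically uses a blur).
    """
    h, w = saliency.shape
    n = h * w
    canvas = np.zeros_like(image, dtype=float) if baseline is None else baseline.astype(float)
    order = np.argsort(-saliency.ravel(), kind="stable")  # most salient pixel first
    scores, fractions = [predict(canvas)], [0.0]
    for start in range(0, n, step):
        rows, cols = np.unravel_index(order[start:start + step], (h, w))
        canvas[rows, cols] = image[rows, cols]  # copy the next batch of pixels
        scores.append(predict(canvas))
        fractions.append(min(start + step, n) / n)
    scores = np.asarray(scores, dtype=float)
    fractions = np.asarray(fractions)
    # Trapezoidal area under the probability vs. fraction-inserted curve.
    return float(np.sum((scores[1:] + scores[:-1]) / 2 * np.diff(fractions)))
```

With a toy `predict` that only looks at one pixel, a saliency map ranking that pixel first scores 0.875 on a 2x2 image, while ranking it last scores 0.125.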

License

This project is licensed under the MIT License.

cgc's People

Contributors

vipinpillai

cgc's Issues

ImageFolder __getitem__ method returning incorrect horizontal_flip transformation parameter

Current Behavior:

The following line flips the augmented image horizontally with 50% probability. However, because self.hor_flip(aug_sample) is not deterministic (it makes its own random decision), the augmented image does not always correspond to the hor_flip parameter:

aug_sample = self.hor_flip(aug_sample)

Expected Behavior:

The hor_flip parameter should be True if and only if the augmented image is a flipped version of the sample (possibly with some crop).

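This can be fixed by setting self.hor_flip = tvf.hflip, a deterministic flip, so that whether the image is flipped is decided only by the hor_flip parameter.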
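The fix boils down to drawing the coin once and letting that single draw decide both the returned flag and whether the flip is applied. A NumPy sketch of the pattern for (H, W, C) arrays follows; `random_hflip` is a hypothetical name. In the dataset's torchvision pipeline, the equivalent would be calling `tvf.hflip` only when the drawn `hor_flip` flag is True (assuming `tvf` is `torchvision.transforms.functional`).

```python
import numpy as np


def random_hflip(sample, rng, p=0.5):
    """Horizontally flip an (H, W, C) array with probability p.

    One random draw decides both the returned flag and the flip itself,
    so the flag always describes the returned image.
    """
    hor_flip = bool(rng.random() < p)
    out = sample[:, ::-1].copy() if hor_flip else sample  # deterministic flip
    return out, hor_flip
```

For (C, H, W) tensors the width axis is the last one, so the NumPy equivalent there would be `sample[..., ::-1]`.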

Steps To Reproduce:

The following code was used to visualize the tensors and verify that sometimes the parameter does not correspond to the augmented image:

import matplotlib.pyplot as plt
import numpy as np

def display_tensors(tensor1, tensor2, hor_flip):
    """Show two (C, H, W) image tensors side by side, labelled with the hor_flip flag."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(f"Flipped? {hor_flip}")

    for ax, tensor in zip(axs, [tensor1, tensor2]):
        # Convert the tensor to a NumPy array (detach and move to CPU first)
        image_np = tensor.detach().cpu().numpy()

        # Min-max scale the values to the [0, 1] range, guarding against constant images
        value_range = image_np.max() - image_np.min()
        image_np = (image_np - image_np.min()) / (value_range if value_range > 0 else 1.0)

        # imshow expects (height, width, channels), so transpose from (channels, height, width)
        if image_np.shape[0] == 3:
            image_np = np.transpose(image_np, (1, 2, 0))

        # Display the image
        ax.imshow(image_np)
        ax.axis("off")

    plt.show(block=True)
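As a complement to the visual check, a hypothetical helper can test the flag numerically on the arrays returned by `tensor.numpy()`. This sketch assumes (C, H, W) arrays and that the only difference between the two images is the optional flip (no crop, resize, or color jitter), which will not hold for the full augmentation pipeline; it is meant for a stripped-down pipeline when reproducing the issue.

```python
import numpy as np


def flip_flag_matches(original, augmented, hor_flip, atol=1e-6):
    """Check that `hor_flip` agrees with the pixels of a (C, H, W) image pair.

    Assumes the augmentation is only an optional horizontal flip.
    """
    flipped = original[..., ::-1]  # flip along the width (last) axis
    is_flipped = np.allclose(augmented, flipped, atol=atol)
    is_same = np.allclose(augmented, original, atol=atol)
    if is_flipped and is_same:
        return True  # horizontally symmetric image: either flag is consistent
    return is_flipped if hor_flip else is_same
```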

Anything else:

I am using your CGC paper for reference: https://arxiv.org/pdf/2110.00527.pdf
