Nadaraya-Watson (NW) Head

Repository containing training and evaluation code for the NW head, an interpretable/explainable, nonparametric classification head that can be used with any neural network.

[Figure: NW head architecture.]

Paper: https://openreview.net/forum?id=iEq6lhG4O3

NW Head

The NW head module is in nwhead/nw.py. In its simplest form, the NW head code is:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NWHead(nn.Module):
    def forward(self,
                query_feats,
                support_feats,
                support_labels):
        """
        Computes Nadaraya-Watson prediction.
        Returns (softmaxed) predicted probabilities.
        Args:
            query_feats: (b, embed_dim)
            support_feats: (b, num_support, embed_dim)
            support_labels: (b, num_support, num_classes)
        """
        # (b, 1, embed_dim)
        query_feats = query_feats.unsqueeze(1)

        # Negative Euclidean distances as similarity scores: (b, 1, num_support)
        scores = -torch.cdist(query_feats, support_feats)
        # Attention weights over the support set
        probs = F.softmax(scores, dim=-1)
        # Weighted average of the support labels: (b, num_classes)
        return torch.bmm(probs, support_labels).squeeze(1)
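
As a quick sanity check, the head can be exercised with dummy tensors (a minimal sketch; the batch and dimension sizes here are illustrative):

import torch
import torch.nn.functional as F

b, num_support, embed_dim, num_classes = 4, 16, 512, 10

head = NWHead()
query_feats = torch.randn(b, embed_dim)
support_feats = torch.randn(b, num_support, embed_dim)
# One-hot support labels: (b, num_support, num_classes)
support_labels = F.one_hot(
    torch.randint(num_classes, (b, num_support)), num_classes).float()

probs = head(query_feats, support_feats, support_labels)
print(probs.shape)    # torch.Size([4, 10])
print(probs.sum(-1))  # each row sums to 1 (a convex combination of support labels)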

Usage

The submodule in nwhead/ is designed to be portable so that it can be dropped into an existing project. An example of its usage can be found in train.py. For example, to train an NW head with a ResNet-18 backbone:

import torch
from nwhead.nw import NWNet

# Data
train_dataset = ...
val_dataset = ...
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)
num_classes = train_dataset.num_classes

# Feature extractor
feature_extractor = load_model('resnet18', num_classes)
feat_dim = 512

# NW Head
network = NWNet(feature_extractor, 
                train_dataset,
                num_classes,
                feat_dim,
                use_nll_loss=True)
network.train()

# Loss and optimizer
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.SGD(network.parameters(), lr=1e-3)

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
network.to(device)

# Training loop
for img, label in train_loader:
    img = img.float().to(device)
    label = label.to(device)
    optimizer.zero_grad()
    with torch.set_grad_enabled(True):
        output = network(img, label)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

To perform evaluation, use the predict() method and pass in your desired inference mode (random, full, or cluster). Make sure to call precompute() beforehand.

network.eval()
network.precompute()
mode = 'full'

for img, label in val_loader:
    img = img.float().to(device)
    label = label.to(device)
    with torch.set_grad_enabled(False):
        output = network.predict(img, mode)
        loss = criterion(output, label)

Interpretability and Explainability

Interpretability via weights

Ranking support images by the scores variable sorts the support set by similarity to the query, as in the figure below. A minimal sketch of the ranking follows the figure.

[Figure: Similarities — support images ranked by similarity to the query.]
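
Assuming the scores tensor from the NWHead forward pass above has been retained (shape (b, 1, num_support)), the ranking is a one-liner:

# Higher (less negative) score means smaller distance, i.e. more similar.
ranking = scores.squeeze(1).argsort(dim=-1, descending=True)  # (b, num_support)
# ranking[i] lists support indices from most to least similar to query i.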

Explainability via support influence

The NW head naturally lends itself to a notion of "support influence" (Section 3.4 in the paper), which finds the most helpful and most harmful examples in the support set for a given query image. The function to compute this is given in util/metric.py:

def support_influence(softmaxes, qlabels, sweights, slabels):
    '''
    Influence is defined as L(rescaled_softmax, qlabel) - L(softmax, qlabel).
    Positive influence => removing support image increases loss => support image was helpful.
    Negative influence => removing support image decreases loss => support image was harmful.
    bs should be 1.

    softmaxes: (bs, num_classes)
    qlabels: One-hot encoded query labels (bs, num_classes)
    sweights: Weights between query and each support (bs, num_support)
    slabels: One-hot encoded support labels (bs, num_support, num_classes)
    '''
    batch_influences = []
    bs = len(softmaxes)
    for bid in range(bs):
        softmax = softmaxes[bid]
        qlabel = qlabels[bid]
        sweight = sweights[bid]
        slabel = slabels[bid]

        qlabel_cat = qlabel.argmax(-1).item()
        slabels_cat = slabel.argmax(-1)

        p = softmax[qlabel_cat]
        indicator = (slabels_cat == qlabel_cat).long()
        influences = torch.log((p - p * sweight) / (p - sweight * indicator))

        batch_influences.append(influences[None])
    return torch.cat(batch_influences, dim=0)
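
The formula follows from deleting support image i and renormalizing the NW weights: the query's probability for its true class changes from p to (p - sweight_i * indicator_i) / (1 - sweight_i), and the influence is the resulting change in negative log-likelihood. A hypothetical usage sketch, assuming softmaxes, qlabels, sweights, and slabels were collected from a forward pass for a single query (bs = 1):

influences = support_influence(softmaxes, qlabels, sweights, slabels)  # (1, num_support)
order = influences[0].argsort(descending=True)  # sort support indices by influence
most_helpful, most_harmful = order[0], order[-1]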

This figure shows the result of ranking support images by support influence, from most helpful to most harmful: [Figure: Influences]

Training

Example command for training NW head:

python train.py \
  --models_dir out/ \
  --data_dir ... \
  --dataset bird \
  --arch resnet18 \
  --train_method nwhead \
  --batch_size 8 \
  --lr 1e-2 \
  --num_epochs 1000 \
  --scheduler_milestones 500 750 \
  --n_way 10

Flag descriptions:

  • --models_dir: directory to save model outputs
  • --data_dir: directory where the dataset lives
  • --dataset: dataset to use
  • --arch: feature extractor, $g_\theta$ in the paper
  • --train_method: model to train, choose from [fchead, nwhead]
  • --scheduler_milestones: epoch milestones at which the scheduler decreases the learning rate
  • --n_way: limits the number of classes in the support set for computational efficiency

This script will train for 1000 epochs and perform evaluation at the end of each epoch using the random, full, and cluster inference modes. Optionally, pass the --use_wandb flag to log training results to Weights & Biases.

Invariant Representation Learning with NW Head

The code supports learning invariant representations across different environments by conditioning the support set on a single environment during training.

[Figure: NWIRM — NW head for invariant representation learning.]

To enable this, specify --train_type irm in the script. Datasets must carry additional metadata indicating which environment each image originates from. This can be provided in two ways (a sketch of the first follows the list):

  1. Pass an env_array variable to NWNet: a 1D environment-indicator array of the same length as the training set.
  2. Pass a support_dataset to NWNet: a list of separate PyTorch datasets, each composed of data from a single environment.
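
A minimal sketch of option 1, assuming env_array is accepted as a keyword argument by NWNet (see nwhead/nw.py for the exact signature; the environment ids here are random, for illustration only):

import numpy as np

num_envs = 3  # hypothetical number of training environments
# One environment id per training example, aligned index-for-index with train_dataset.
env_array = np.random.randint(num_envs, size=len(train_dataset))

network = NWNet(feature_extractor,
                train_dataset,
                num_classes,
                feat_dim,
                env_array=env_array,
                use_nll_loss=True)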

Requirements

This code was run and tested on an NVIDIA A6000 GPU with the following dependencies:

  • python 3.7.11
  • torch 1.10.1
  • torchvision 0.11.2
  • numpy 1.21.5

Citation

If you use the NW head or some part of the code, please consider citing:

Alan Q. Wang and Mert R. Sabuncu, "A Flexible Nadaraya-Watson Head Can Offer Explainable and Calibrated Classification" (TMLR 2023)

@article{
    wang2022nwhead,
    title={A Flexible Nadaraya-Watson Head Can Offer Explainable and Calibrated Classification},
    author={Alan Q. Wang and Mert R. Sabuncu},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2022},
    url={https://openreview.net/forum?id=iEq6lhG4O3},
}

For code related to invariant representation learning with the NW head, please consider citing:

Alan Q. Wang, Minh Nguyen, and Mert R. Sabuncu, "Learning Invariant Representations with a Nonparametric Nadaraya-Watson Head" (NeurIPS 2023)

@inproceedings{
    wang2023nwheadirm,
    title={Learning Invariant Representations with a Nonparametric Nadaraya-Watson Head},
    author={Alan Q. Wang and Minh Nguyen and Mert R. Sabuncu},
    booktitle={Advances in Neural Information Processing Systems},
    year={2023},
}
