
Comments (6)

woodszp commented on May 24, 2024

First of all, thank you for your reply. The reference code you provided is very helpful to me.
Thank you also for explaining the method used in the visualization.
Have a good day!

from renet.

dahyun-kang commented on May 24, 2024

Hello,

Thank you very much for the comment and your interest in our work!
I am afraid we do not have concrete plans to update the visualization implementation at this point.
Instead, here are some code snippets that may help you reproduce the attention maps.

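import cv2
import matplotlib.pyplot as plt
import numpy as np

# Pad an (H, W, 3) image with zero rows at the bottom so that its height equals h_max.
padim = lambda x, h_max: np.concatenate((x, np.zeros((h_max - x.shape[0],) + x.shape[1:], dtype=x.dtype)), axis=0) if x.shape[0] < h_max else x

def attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs):
    # Images are (H, W, 3) arrays, so the height is axis 0.
    h_max = int(np.max([img_s.shape[0], img_q.shape[0]]))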
    attn_s_qs_normalized = (attn_s_qs - attn_s_qs.min()) / (attn_s_qs.max() - attn_s_qs.min())
    attn_q_qs_normalized = (attn_q_qs - attn_q_qs.min()) / (attn_q_qs.max() - attn_q_qs.min())
    img_s_heatmap = show_heatmap_on_image(img_s, attn_s_qs_normalized)
    img_q_heatmap = show_heatmap_on_image(img_q, attn_q_qs_normalized)
    im = np.concatenate((padim(img_s_heatmap, h_max), padim(img_q_heatmap, h_max)), axis=1)
    plt.imshow(im)
    return plt

def show_heatmap_on_image(img, attn):
    heatmap = cv2.applyColorMap(np.uint8(255 * attn), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    img_ = np.float32(img) / 255
    attended_img = heatmap + np.float32(img_)
    attended_img = attended_img / np.max(attended_img)
    return np.uint8(255 * attended_img)

Note that we do not use Grad-CAM for Figs. 1 and 6 in the paper.
What we visualized (the attn argument in the snippet) is either the averaged feature activations or the 2D attention maps, as specified in Sec. 5.6 of the paper.
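
As a minimal sketch of the "averaged feature activations" case, the snippet below takes a channel-wise mean of a feature map and min-max scales it to [0, 1]. Reading "averaged" as a mean over channels is an assumption, and the function name, the 640-channel 5x5 shape, and the random input are all hypothetical stand-ins rather than values taken from the repository.

```python
import numpy as np

def mean_activation_map(feat):
    """Average a (C, H, W) feature map over channels and min-max scale it to [0, 1]."""
    attn = feat.mean(axis=0)
    span = attn.max() - attn.min()
    # Guard against a constant map, which would otherwise divide by zero.
    if span == 0:
        return np.zeros_like(attn)
    return (attn - attn.min()) / span

# Hypothetical stand-in for an encoder output: 640 channels on a 5x5 grid.
feat = np.random.default_rng(0).random((640, 5, 5)).astype(np.float32)
attn = mean_activation_map(feat)
```

The result is a 2D map that can play the role of attn in show_heatmap_on_image once it has been resized to the image resolution.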

Have a great day! 😃

woodszp commented on May 24, 2024

Dear Prof. Dahyun Kang,
I'm sorry to bother you again.
I don't understand exactly what inputs attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs) expects.
Could you help me modify the code below?

import torch.nn as nn
import torch.nn.functional as F  # needed for F.softmax in ccaother
import numpy as np
import cv2
from common.utils import load_model, setup_run
from models.renet import RENet
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

# Pad an (H, W, 3) image with zero rows at the bottom so that its height equals h_max.
padim = lambda x, h_max: np.concatenate((x, np.zeros((h_max - x.shape[0],) + x.shape[1:], dtype=x.dtype)),
                                        axis=0) if x.shape[0] < h_max else x


def attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs):
    # Images are (H, W, 3) arrays, so the height is axis 0.
    h_max = int(np.max([img_s.shape[0], img_q.shape[0]]))
    attn_s_qs_normalized = (attn_s_qs - attn_s_qs.min()) / (attn_s_qs.max() - attn_s_qs.min())
    attn_q_qs_normalized = (attn_q_qs - attn_q_qs.min()) / (attn_q_qs.max() - attn_q_qs.min())
    img_s_heatmap = show_heatmap_on_image(img_s, attn_s_qs_normalized)
    img_q_heatmap = show_heatmap_on_image(img_q, attn_q_qs_normalized)
    im = np.concatenate((padim(img_s_heatmap, h_max), padim(img_q_heatmap, h_max)), axis=1)
    #print("im shape is ", im.shape)
    plt.imshow(im)
    return plt


def show_heatmap_on_image(img, attn):
    heatmap = cv2.applyColorMap(np.uint8(255 * attn), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    #heatmap = np.transpose(heatmap, [2,0,1])
    img_ = np.float32(img) / 255
    attended_img = heatmap + np.float32(img_)
    attended_img = attended_img / np.max(attended_img)
    return np.uint8(255 * attended_img)

def ccaother(self, spt, qry):
    # shifting channel activations by the channel mean
    spt = self.normalize_feature(spt)
    qry = self.normalize_feature(qry)

    # (S * C * Hs * Ws, Q * C * Hq * Wq) -> Q * S * Hs * Ws * Hq * Wq
    corr4d = self.get_4d_correlation_map(spt, qry)
    num_qry, way, H_s, W_s, H_q, W_q = corr4d.size()

    # corr4d refinement
    corr4d = self.cca_module(corr4d.view(-1, 1, H_s, W_s, H_q, W_q))
    corr4d_s = corr4d.view(num_qry, way, H_s * W_s, H_q, W_q)
    corr4d_q = corr4d.view(num_qry, way, H_s, W_s, H_q * W_q)

    # normalizing the entities for each side to be zero-mean and unit-variance to stabilize training
    corr4d_s = self.gaussian_normalize(corr4d_s, dim=2)
    corr4d_q = self.gaussian_normalize(corr4d_q, dim=4)

    # applying softmax for each side
    corr4d_s = F.softmax(corr4d_s / self.args.temperature_attn, dim=2)
    corr4d_s = corr4d_s.view(num_qry, way, H_s, W_s, H_q, W_q)
    corr4d_q = F.softmax(corr4d_q / self.args.temperature_attn, dim=4)
    corr4d_q = corr4d_q.view(num_qry, way, H_s, W_s, H_q, W_q)

    # summing up matching scores
    attn_s = corr4d_s.sum(dim=[4, 5])
    attn_q = corr4d_q.sum(dim=[2, 3])

    # # applying attention
    # spt_attended = attn_s.unsqueeze(2) * spt.unsqueeze(0)
    # qry_attended = attn_q.unsqueeze(2) * qry.unsqueeze(1)

    return attn_s, attn_q

image_size = 84
resize_size = 92

transform = transforms.Compose([
    transforms.Resize([resize_size, resize_size]),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
                         np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))])

if __name__ == '__main__':
    # load data
    img_s_path = './ImageOutput/Imagesq/Dogs.jpg'
    img_q_path = './ImageOutput/Imagesq/Dogq.jpg'
    img_s_bs = transform(Image.open(img_s_path).convert('RGB'))
    img_q_bs = transform(Image.open(img_q_path).convert('RGB'))
    img_s = img_s_bs.cuda()
    img_q = img_q_bs.cuda()
    img_s = img_s.unsqueeze(0).repeat(1, 1, 1, 1)
    img_q = img_q.unsqueeze(0).repeat(1, 1, 1, 1)

    # load model for extract feature and attn
    args = setup_run(arg_mode='test')
    ''' define model '''
    model = RENet(args).cuda()
    pre_path = '/home/grcwoods/WZP/Baseline-renet/MFCS-checkpoints/' + args.dataset + '/' + str(
        args.shot) + 'shot-' + str(args.way) + 'way' + '/tiere_bs_5w5s/max_acc.pth'
    model = load_model(model, pre_path)
    model = nn.DataParallel(model, device_ids=args.device_ids)

    model.module.mode = 'encoder'
    feature_img_s = model(img_s)
    feature_img_q = model(img_q)

    model.module.mode = 'ccaother'
    attn_s, attn_q = model((feature_img_s, feature_img_q))

    # draw cam
    print("feature_img_s", type(feature_img_s))
    feature_img_s = feature_img_s.cpu().detach().numpy()
    feature_img_q = feature_img_q.cpu().detach().numpy()
    attn_s = attn_s.cpu().detach().numpy()
    attn_q = attn_q.cpu().detach().numpy()
    attn_heatmap(img_s_bs, img_q_bs, attn_s, attn_q)

    print(">>>>>>>>>>>>>>>>> finish")

I really appreciate your help. I'm looking forward to hearing from you.

dahyun-kang commented on May 24, 2024

Hello again,

Below are the tensor dimensions I used for visualization:

attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs)
``` input argument tensor dimensions
img_s.shape: (H, W, 3)
img_q.shape: (H, W, 3)
attn_s_qs.shape: (H, W)
attn_q_qs.shape: (H, W)
```
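
The following is a minimal sketch of how the tensors from the question could be brought into these shapes. It assumes the image was normalized with the mean and standard deviation from the transform posted earlier in this thread, and that the attention tensor has the (num_qry, way, H, W) layout returned by ccaother. The helper names and the random inputs are hypothetical, and nearest-neighbour resizing is used only to keep the example NumPy-only.

```python
import numpy as np

# Per-channel statistics copied from the transform posted earlier in this thread.
MEAN = np.array([125.3, 123.0, 113.9], dtype=np.float32)
STD = np.array([63.0, 62.1, 66.7], dtype=np.float32)

def to_hwc_uint8(img_chw):
    """Undo Normalize on a (3, H, W) float array and return an (H, W, 3) uint8 image."""
    img = np.transpose(img_chw, (1, 2, 0)) * STD + MEAN
    return np.clip(np.rint(img), 0, 255).astype(np.uint8)

def upsample_nearest(attn, height, width):
    """Nearest-neighbour resize of a 2D map to (height, width)."""
    rows = np.arange(height) * attn.shape[0] // height
    cols = np.arange(width) * attn.shape[1] // width
    return attn[rows[:, None], cols[None, :]]

# Hypothetical stand-ins for one normalized image and one attention tensor.
rng = np.random.default_rng(0)
img_chw = rng.standard_normal((3, 84, 84)).astype(np.float32)
attn_4d = rng.random((1, 1, 5, 5)).astype(np.float32)  # (num_qry, way, H_s, W_s)

img_hwc = to_hwc_uint8(img_chw)                    # (84, 84, 3), uint8
attn_hw = upsample_nearest(attn_4d[0, 0], 84, 84)  # (84, 84), one query/class pair
```

Arrays prepared this way match the (H, W, 3) and (H, W) shapes listed above. For smoother heatmaps, a bilinear resize such as cv2.resize could replace the nearest-neighbour helper.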

I hope it helps.
Have a great day! 😃

Best,
Dahyun.

woodszp commented on May 24, 2024

Thank you very much, your answer helped me immensely!
Have a great day.

dahyun-kang commented on May 24, 2024

I'm glad it was helpful 😃
