First of all, thank you for your reply. The reference code you provided is very helpful to me.
Thank you also for explaining the methods used in the visualization.
Have a good day!
from renet.
Hello,
Thank you very much for the comment and your interest in our work!
Unfortunately, we do not have concrete plans to update the visualization implementation at this point.
Instead, here are some code snippets that may help you reproduce the attention maps.
```python
import cv2
import matplotlib.pyplot as plt
import numpy as np

def padim(x, h_max):
    # zero-pad an (H, W, 3) image along the height axis so that it has h_max rows
    if x.shape[0] >= h_max:
        return x
    pad = np.zeros((h_max - x.shape[0], x.shape[1], x.shape[2]), dtype=x.dtype)
    return np.concatenate((x, pad), axis=0)

def attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs):
    # img_*: (H, W, 3) uint8 images; attn_*: (H, W) maps with the same spatial size
    h_max = int(np.max([img_s.shape[0], img_q.shape[0]]))
    # min-max normalize each attention map to [0, 1]
    attn_s_qs_normalized = (attn_s_qs - attn_s_qs.min()) / (attn_s_qs.max() - attn_s_qs.min())
    attn_q_qs_normalized = (attn_q_qs - attn_q_qs.min()) / (attn_q_qs.max() - attn_q_qs.min())
    img_s_heatmap = show_heatmap_on_image(img_s, attn_s_qs_normalized)
    img_q_heatmap = show_heatmap_on_image(img_q, attn_q_qs_normalized)
    # place the support and query overlays side by side
    im = np.concatenate((padim(img_s_heatmap, h_max), padim(img_q_heatmap, h_max)), axis=1)
    plt.imshow(im)
    return plt

def show_heatmap_on_image(img, attn):
    # colorize the [0, 1] map with the JET colormap (OpenCV returns BGR, so convert to RGB)
    heatmap = cv2.applyColorMap(np.uint8(255 * attn), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    img_ = np.float32(img) / 255
    # blend by addition, then rescale so the brightest value is 1
    attended_img = heatmap + img_
    attended_img = attended_img / np.max(attended_img)
    return np.uint8(255 * attended_img)
```
Note that we do not use Grad-CAM for Figs. 1 and 6 in the paper.
What we visualized (which corresponds to `attn` in the snippet) is the averaged feature activations or the 2D attention maps, as specified in Sec. 5.6 of the paper.
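As a rough illustration of what "averaged feature activations" can mean in practice, the sketch below collapses a (C, H, W) feature map into an (H, W) map by averaging over channels and then min-max normalizes it. This is an assumption about the procedure, not the authors' code: the function name, the epsilon, and the 640-channel 5x5 example shape (typical of a ResNet-12 backbone on 84x84 inputs) are illustrative.

```python
import numpy as np

def averaged_activation_map(feature):
    """Average a (C, H, W) feature map over channels and min-max normalize to [0, 1].

    Hypothetical helper; the exact averaging used for the paper's figures may differ.
    """
    act = feature.mean(axis=0)           # (H, W)
    lo, hi = act.min(), act.max()
    return (act - lo) / (hi - lo + 1e-8)  # epsilon guards against a constant map

# hypothetical 640-channel, 5x5 feature map
feat = np.random.rand(640, 5, 5).astype(np.float32)
attn = averaged_activation_map(feat)
print(attn.shape)  # (5, 5)
```

The resulting map is at feature-map resolution, so it still has to be upsampled to the image size before it can be overlaid with `show_heatmap_on_image`.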
Have a great day!
Dear Prof. Dahyun Kang,
I'm sorry to bother you again.
I don't understand exactly what inputs the parameters of `attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs)` expect.
Could you help me fix the code below?
```python
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from common.utils import load_model, setup_run
from models.renet import RENet


def padim(x, h_max):
    # zero-pad an (H, W, 3) image along the height axis to h_max rows
    if x.shape[0] >= h_max:
        return x
    pad = np.zeros((h_max - x.shape[0], x.shape[1], x.shape[2]), dtype=x.dtype)
    return np.concatenate((x, pad), axis=0)


def attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs):
    h_max = int(np.max([img_s.shape[0], img_q.shape[0]]))
    attn_s_qs_normalized = (attn_s_qs - attn_s_qs.min()) / (attn_s_qs.max() - attn_s_qs.min())
    attn_q_qs_normalized = (attn_q_qs - attn_q_qs.min()) / (attn_q_qs.max() - attn_q_qs.min())
    img_s_heatmap = show_heatmap_on_image(img_s, attn_s_qs_normalized)
    img_q_heatmap = show_heatmap_on_image(img_q, attn_q_qs_normalized)
    im = np.concatenate((padim(img_s_heatmap, h_max), padim(img_q_heatmap, h_max)), axis=1)
    # print("im shape is ", im.shape)
    plt.imshow(im)
    return plt


def show_heatmap_on_image(img, attn):
    heatmap = cv2.applyColorMap(np.uint8(255 * attn), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    # heatmap = np.transpose(heatmap, [2, 0, 1])
    img_ = np.float32(img) / 255
    attended_img = heatmap + img_
    attended_img = attended_img / np.max(attended_img)
    return np.uint8(255 * attended_img)


# Intended as a custom RENet method (assumed to be reached via model.module.mode = 'ccaother');
# it returns the support/query attention maps instead of the attended features.
def ccaother(self, spt, qry):
    # shifting channel activations by the channel mean
    spt = self.normalize_feature(spt)
    qry = self.normalize_feature(qry)
    # (S * C * Hs * Ws, Q * C * Hq * Wq) -> Q * S * Hs * Ws * Hq * Wq
    corr4d = self.get_4d_correlation_map(spt, qry)
    num_qry, way, H_s, W_s, H_q, W_q = corr4d.size()
    # corr4d refinement
    corr4d = self.cca_module(corr4d.view(-1, 1, H_s, W_s, H_q, W_q))
    corr4d_s = corr4d.view(num_qry, way, H_s * W_s, H_q, W_q)
    corr4d_q = corr4d.view(num_qry, way, H_s, W_s, H_q * W_q)
    # normalizing the entities for each side to be zero-mean and unit-variance to stabilize training
    corr4d_s = self.gaussian_normalize(corr4d_s, dim=2)
    corr4d_q = self.gaussian_normalize(corr4d_q, dim=4)
    # applying softmax for each side
    corr4d_s = F.softmax(corr4d_s / self.args.temperature_attn, dim=2)
    corr4d_s = corr4d_s.view(num_qry, way, H_s, W_s, H_q, W_q)
    corr4d_q = F.softmax(corr4d_q / self.args.temperature_attn, dim=4)
    corr4d_q = corr4d_q.view(num_qry, way, H_s, W_s, H_q, W_q)
    # summing up matching scores
    attn_s = corr4d_s.sum(dim=[4, 5])
    attn_q = corr4d_q.sum(dim=[2, 3])
    # # applying attention
    # spt_attended = attn_s.unsqueeze(2) * spt.unsqueeze(0)
    # qry_attended = attn_q.unsqueeze(2) * qry.unsqueeze(1)
    return attn_s, attn_q


image_size = 84
resize_size = 92
transform = transforms.Compose([
    transforms.Resize([resize_size, resize_size]),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
                         np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))])

if __name__ == '__main__':
    # load the support and query images
    img_s_path = './ImageOutput/Imagesq/Dogs.jpg'
    img_q_path = './ImageOutput/Imagesq/Dogq.jpg'
    img_s_bs = transform(Image.open(img_s_path).convert('RGB'))
    img_q_bs = transform(Image.open(img_q_path).convert('RGB'))
    # add a batch dimension: (3, 84, 84) -> (1, 3, 84, 84)
    img_s = img_s_bs.unsqueeze(0).cuda()
    img_q = img_q_bs.unsqueeze(0).cuda()

    # load the model to extract features and attention
    args = setup_run(arg_mode='test')
    # define the model
    model = RENet(args).cuda()
    pre_path = '/home/grcwoods/WZP/Baseline-renet/MFCS-checkpoints/' + args.dataset + '/' + str(
        args.shot) + 'shot-' + str(args.way) + 'way' + '/tiere_bs_5w5s/max_acc.pth'
    model = load_model(model, pre_path)
    model = nn.DataParallel(model, device_ids=args.device_ids)

    model.module.mode = 'encoder'
    feature_img_s = model(img_s)
    feature_img_q = model(img_q)
    model.module.mode = 'ccaother'
    attn_s, attn_q = model((feature_img_s, feature_img_q))

    # draw the heatmaps
    print("feature_img_s", type(feature_img_s))
    feature_img_s = feature_img_s.cpu().detach().numpy()
    feature_img_q = feature_img_q.cpu().detach().numpy()
    attn_s = attn_s.cpu().detach().numpy()
    attn_q = attn_q.cpu().detach().numpy()
    attn_heatmap(img_s_bs, img_q_bs, attn_s, attn_q)
    print(">>>>>>>>>>>>>>>>> finish")
```
I really appreciate your help. I'm looking forward to hearing from you.
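To see why `attn_s` and `attn_q` from `ccaother` come out at feature-map resolution rather than image resolution, here is a simplified NumPy sketch of its attention step. It is an approximation under stated assumptions: it omits the CCA refinement (`cca_module`) and the Gaussian normalization, and the helper names, the random correlation tensor, and the 0.2 temperature are illustrative only.

```python
import numpy as np

def softmax(x, axis):
    # numerically stable softmax along one axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def co_attention(corr4d, temperature=1.0):
    """Simplified sketch (no CCA refinement, no Gaussian normalization).

    corr4d: (Q, S, Hs, Ws, Hq, Wq) correlations between support and query positions.
    Returns attn_s: (Q, S, Hs, Ws) and attn_q: (Q, S, Hq, Wq).
    """
    Q, S, Hs, Ws, Hq, Wq = corr4d.shape
    # for every query position, softmax over all support positions
    corr_s = softmax(corr4d.reshape(Q, S, Hs * Ws, Hq, Wq) / temperature, axis=2)
    corr_s = corr_s.reshape(Q, S, Hs, Ws, Hq, Wq)
    # for every support position, softmax over all query positions
    corr_q = softmax(corr4d.reshape(Q, S, Hs, Ws, Hq * Wq) / temperature, axis=4)
    corr_q = corr_q.reshape(Q, S, Hs, Ws, Hq, Wq)
    # sum matching scores over the other image's positions
    attn_s = corr_s.sum(axis=(4, 5))
    attn_q = corr_q.sum(axis=(2, 3))
    return attn_s, attn_q

corr4d = np.random.randn(1, 1, 5, 5, 5, 5)
attn_s, attn_q = co_attention(corr4d, temperature=0.2)
print(attn_s.shape, attn_q.shape)  # (1, 1, 5, 5) (1, 1, 5, 5)
```

Because each softmax sums to 1, `attn_s` sums to the number of query positions (25 for a 5x5 map). More importantly, both maps keep a leading `(num_qry, way)` pair of axes and the spatial size of the feature map.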
Hello again,
Below are the input dimensions I used for visualization with
`attn_heatmap(img_s, img_q, attn_s_qs, attn_q_qs)`:
``` input argument tensor dimensions
img_s.shape: (H, W, 3)
img_q.shape: (H, W, 3)
attn_s_qs.shape: (H, W)
attn_q_qs.shape: (H, W)
```
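One way to get from the question's tensors to these shapes is sketched below. Two steps are involved. First, undo the `transforms.Normalize` step and move channels last to get an (H, W, 3) uint8 image. Second, squeeze the `(1, 1, h, w)` attention map and upsample it to the image size. This matters because `show_heatmap_on_image` adds the heatmap to the image element-wise, so their sizes must match. The helper names, the nearest-neighbour upsampling, and the random inputs are illustrative assumptions. `cv2.resize` with `cv2.INTER_LINEAR` is a common alternative that gives smoother maps.

```python
import numpy as np

# mean/std used by transforms.Normalize in the question's code
MEAN = np.array([125.3, 123.0, 113.9]) / 255.0
STD = np.array([63.0, 62.1, 66.7]) / 255.0

def to_hwc_uint8(img_chw):
    """Hypothetical helper: undo the normalization, (3, H, W) float -> (H, W, 3) uint8."""
    img = img_chw.transpose(1, 2, 0) * STD + MEAN
    return np.uint8(np.clip(img, 0.0, 1.0) * 255)

def upsample_attn(attn, height, width):
    """Hypothetical helper: drop singleton axes of a (1, 1, h, w) map and
    upsample it to (height, width) by nearest-neighbour indexing."""
    a = np.squeeze(attn)
    h, w = a.shape
    rows = np.arange(height) * h // height
    cols = np.arange(width) * w // width
    return a[rows[:, None], cols[None, :]]

# random stand-ins shaped like img_s_bs.numpy() and attn_s after .cpu().detach().numpy()
img_s_chw = np.random.randn(3, 84, 84).astype(np.float32)
attn_s = np.random.rand(1, 1, 5, 5)

img_s = to_hwc_uint8(img_s_chw)
attn_s_qs = upsample_attn(attn_s, img_s.shape[0], img_s.shape[1])
print(img_s.shape, img_s.dtype, attn_s_qs.shape)  # (84, 84, 3) uint8 (84, 84)
```

The query side is handled the same way, after which the four arrays match the dimensions listed above.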
I hope it helps.
Have a great day!
Best,
Dahyun.
Thank you very much; your answer helped me immensely!
Have a great day.
I'm glad it was helpful!