
MobileNet implementation of CSRA (csra, OPEN, 4 comments)

ghylander commented on August 20, 2024
MobileNet implementation of CSRA


Comments (4)

Kevinz-code commented on August 20, 2024

Hi, @ghylander
Thanks for your question and implementation.

Here are the three key steps in the CSRA module:

  1. generate the attention scores s^i_j and the class-specific feature a^i (Eq. 2 and Eq. 3)
  2. combine each a^i with the average-pooled feature g to get f^i
  3. send f^i on to get the i-th class logit (Eq. 6)

Input of CSRA: the feature map before average pooling (B x dimension x H x W)
Output of CSRA: the logits (B x C)
Steps 1-3 can be expressed by Eq. 8 and are implemented in the CSRA class in pipeline/csra.py (sketched below). You can refer to #5 for our reply.

About MobileNetV3: you can apply the three CSRA steps in place of the average pooling on the (7, 7, 960) tensor; the final logit will then be the output. There might be a small accuracy drop, since the H-Swish structure is discarded in this case.
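
For concreteness, a minimal PyTorch sketch of the three steps, assembled from the snippets quoted later in this thread (the class name is illustrative; see pipeline/csra.py for the actual implementation):

import torch
import torch.nn as nn

class CSRASketch(nn.Module):
    """Illustrative CSRA head: feature map (B, d, H, W) -> logits (B, C)."""
    def __init__(self, input_dim, num_classes, T, lam):
        super().__init__()
        self.T = T      # temperature of the per-class spatial softmax
        self.lam = lam  # lambda, weight of the residual attention logit
        self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)

    def forward(self, x):
        # Step 1: per-location class scores s^i_j, scaled by each classifier's norm
        score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = score.flatten(2)                # (B, C, H*W)
        # Step 2: base logit, i.e. the classifier applied to the avg-pooled feature g
        base_logit = torch.mean(score, dim=2)   # (B, C)
        # Step 3: class-specific residual attention, combined as in Eq. 8
        score_soft = torch.softmax(score * self.T, dim=2)
        att_logit = torch.sum(score * score_soft, dim=2)
        return base_logit + self.lam * att_logit  # (B, C)

For MobileNetV3-Large, input_dim would be 960 and this module would replace the final average pooling, as described above.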

For more details, refer to our paper.

Best,


ghylander commented on August 20, 2024

Thanks for the reply, I think it made things clearer.
If I'm not mistaken, the diagram below shows the full "tensor flow" (heh, no pun intended) of the CSRA module as implemented in the code and as used in the paper:

[diagram: tensor shapes flowing through the CSRA module]
(made a mistake and fixed the diagram)

I do have a question regarding the code. I work mainly with TensorFlow and I'm not fully familiar with PyTorch workflows/structures. What does this function do?:

score = score.flatten(2)

I assume it flattens the dimensions of a tensor, but why is this done?
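
(For reference, PyTorch's Tensor.flatten(start_dim) merges every dimension from start_dim onward into one; a quick check with illustrative shapes:)

import torch

x = torch.randn(32, 80, 7, 7)  # (B, C, H, W)
print(x.flatten(2).shape)      # torch.Size([32, 80, 49]), i.e. (B, C, H*W)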
The same applies to the transpose of the weight norm:

torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)

Why is the result of the normalization transposed before being applied to the output of the fully connected layer?
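
(A small shape check hints at the answer; the 960-input, 80-class head here is a made-up example mirroring the snippet above.)

import torch
import torch.nn as nn

head = nn.Conv2d(960, 80, 1, bias=False)              # weight shape (80, 960, 1, 1)
norms = torch.norm(head.weight, dim=1, keepdim=True)  # (80, 1, 1, 1)
print(norms.transpose(0, 1).shape)                    # (1, 80, 1, 1), which broadcasts
                                                      # against scores of shape (B, 80, H, W)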


ghylander commented on August 20, 2024

As an update to this, I managed to implement almost all of CSRA in my MobileNetV3 model.
I had to dive deep into both the PyTorch and TensorFlow docs to fully translate one into the other.
The only bit I'm currently missing is the multi-head attention, which also ties into a doubt I have about the pipeline/csra.py file.

As far as I understand, using multi-head attention creates a number of parallel heads, and CSRA is applied within each head with its own temperature T. This results in one vector of shape (Batch, Classes) per head. All of these vectors are then added together element-wise, and the resulting tensor is sent through a sigmoid activation function.

Now, in the case where num_heads = 1, the resulting tensor of the single head is also sent through the sigmoid activation.

Is that correct?
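
(A minimal sketch of that reading, reusing the CSRASketch module from the earlier comment; the head temperatures here are placeholders, not the repository's exact settings.)

import torch.nn as nn

class MultiHeadCSRASketch(nn.Module):
    """Element-wise sum of per-head CSRA logits."""
    def __init__(self, input_dim, num_classes, temps=(1, 2, 4), lam=0.1):
        super().__init__()
        self.heads = nn.ModuleList(
            [CSRASketch(input_dim, num_classes, T, lam) for T in temps]
        )

    def forward(self, x):
        # Sum the (B, C) logits of every head; a sigmoid would then be applied
        # afterwards (or folded into the loss, e.g. BCEWithLogitsLoss)
        return sum(head(x) for head in self.heads)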


ghylander commented on August 20, 2024

Update on this: I had to put it on hold and am returning to it now. I managed to implement the drop-in CSRA module in TensorFlow (v2.9) for the MobileNetV3-Large backbone, and I have half-managed to implement the trainable block.

Could you clarify a few things for me? Some of my trouble comes from translating PyTorch to TensorFlow:

1.- In the CSRA class declared on line 6 of pipeline/csra.py: in the forward() method, on line 18, what you are performing is a weight normalisation of the feature vector, isn't it?

2.- Then, on line 19, you flatten the height and width of the resulting tensor. I'm not familiar with PyTorch's flatten() method, but I assume it works just like NumPy's, by 'appending' the nested dimensions one after the other, correct?

3.- On line 20, you compute the mean over the HxW axis. Looking at Figure 1 in the paper, this seems to be equivalent to an average pooling operation. Is this correct? (See the quick check after the snippet below.)

4.- Lastly, your current implementation works with logits. What impact would it have on CSRA performance to apply an activation function (softmax or sigmoid) to the output logits (needless to say, the loss function used would account for this)? i.e.:

score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)
score = score.flatten(2)
base_logit = torch.mean(score, dim=2)
score_soft = self.softmax(score * self.T)
att_logit = torch.sum(score * score_soft, dim=2)

output = torch.sigmoid(base_logit + self.lam * att_logit)
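
(As a side check on points 2 and 3: the mean over the flattened spatial dimension does match global average pooling, assuming NCHW layout.)

import torch

x = torch.randn(2, 5, 7, 7)                        # (B, C, H, W)
flat = x.flatten(2)                                # (B, C, H*W), NumPy-style reshape
gap = torch.nn.AdaptiveAvgPool2d(1)(x).flatten(1)  # global average pooling -> (B, C)
print(torch.allclose(flat.mean(dim=2), gap))       # True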

Here's my current drop-in CSRA implementation in TF v2.9:

import tensorflow as tf

# Defining the model input tensor; (None, None) spatial dimensions mean dynamic shape
inputs = tf.keras.Input(shape=(None, None, 3))

# Base model is the backbone; features is the resulting tensor with shape (batch, H, W, d) = (32, 7, 7, 960)
features = base_model(inputs, training=False)

# Applying a 1x1 convolution (a position-wise fully connected layer) to the backbone output
attentions = tf.keras.layers.Conv2D(1280, kernel_size=1, padding='same', use_bias=False)(features)
# Applying Batch Normalization to the FC layer output
attentions = tf.keras.layers.BatchNormalization()(attentions)
# Applying Average Pooling to the normalized FC output
avg_attentions = tf.keras.layers.GlobalAveragePooling2D()(attentions)
# Applying Max Pooling to the normalized FC output
max_attentions = tf.keras.layers.GlobalMaxPooling2D()(attentions)
# Computing CSRA logit output with lambda = 0.2
csra_output = avg_attentions + 0.2 * max_attentions
# Applying a dropout layer
csra_output_dropout = tf.keras.layers.Dropout(0.2)(csra_output)
# Applying a 2nd FC layer with sigmoid activation
outputs_csra = tf.keras.layers.Dense(2, activation='sigmoid', kernel_initializer='random_uniform', bias_initializer='zeros')(csra_output_dropout)

model_csra = tf.keras.Model(inputs, outputs_csra)
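
(A quick way to sanity-check the wiring; 224x224 is just an example input size for MobileNetV3-Large.)

# Dummy forward pass; expect an output of shape (4, 2)
dummy = tf.random.uniform((4, 224, 224, 3))
print(model_csra(dummy).shape)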

Does this look correct to you?

