
Comments (9)

soskek commented on July 18, 2024

An attention mask is multiplied into the attention weights (not the vectors themselves), so the dimensionality of the vectors has no connection to the shape of the masks here.
See https://github.com/soskek/attention_is_all_you_need/blob/master/net.py#L162-L165

It will be easier to understand if you read it with "vector" replaced by "token" or "word". Additionally, let me rephrase some parts.

The value at index (b, i, j) (of the whole mask array) indicates whether, for the b-th pair in the batch, the i-th source word can attend to the j-th target word or not.
For example, padding tokens (which have -1 instead of a word index) must not be attended to, so the corresponding values are False.
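
To illustrate, here is a minimal numpy sketch (my own simplification, not the repository's exact code) of how a boolean mask of shape (batch, source_length, target_length) can be applied to attention logits before the softmax:

import numpy as np

def masked_softmax(scores, mask):
    # scores: raw attention logits, shape (batch, source_length, target_length)
    # mask:   boolean array of the same shape; False marks pairs that must not attend
    scores = np.where(mask, scores, -np.inf)              # block forbidden positions
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights = np.where(mask, weights, 0.0)                # keep blocked weights at exactly 0
    denom = np.maximum(weights.sum(axis=-1, keepdims=True), 1e-9)
    return weights / denom                                # fully masked rows stay all-zero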


machanic commented on July 18, 2024

Yes, thank you, I understand now. Avoiding attention to the -1 padding is the reason for writing:

def make_attention_mask(source_block, target_block):
        mask = (target_block[:, None, :] >= 0) * \
            (source_block[:, :, None] >= 0) 
        return mask


soskek commented on July 18, 2024

Have you read the paper? See Section 3.1:

We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i.

make_history_mask corresponds to this part.
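
For reference, a history (i.e. causal) mask is just a lower-triangular boolean matrix broadcast over the batch; this is my own simplified sketch, not the exact code in net.py:

import numpy as np

def make_history_mask(block):
    # block: token ids of shape (batch, length); only its shape matters here
    batch, length = block.shape
    arange = np.arange(length)
    # True at (i, j) iff j <= i, i.e. position i may attend only to itself
    # and to earlier positions, never to subsequent ones
    history_mask = arange[None, :] <= arange[:, None]     # (length, length), lower-triangular
    return np.broadcast_to(history_mask[None], (batch, length, length))

In the repository this mask is multiplied elementwise into the decoder-side padding mask, so an entry stays True only if it passes both the padding check and the "no attending to the future" check.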


machanic commented on July 18, 2024

@soskek Do you mean the purpose of the mask is to let the decoder input only see information from positions earlier than position i?

Can you explain the following code to me? I read your code again but still can't get the point:

xx_mask = self.make_attention_mask(x_block, x_block)
xy_mask = self.make_attention_mask(y_in_block, x_block)
yy_mask = self.make_attention_mask(y_in_block, y_in_block)
yy_mask *= self.make_history_mask(y_in_block)


machanic commented on July 18, 2024

make_attention_mask returns a mask of shape (batch, source_length, target_length).
What is it?
history_mask has True values in a lower-triangular matrix, which is then np.tile-d to shape (batch, length, length).
Can you give me some more information about this mysterious code?


soskek commented on July 18, 2024

Consider the batch=1 case and ignore axis 0 for simplicity, i.e. the shape is (source_length, target_length).
In a naive formulation of the attention mechanism, we have one "source" vector and several "target" vectors, where the former attends to each of the latter. (*, target_length)
Furthermore, we can calculate the attention of each independent "source" vector all at once. That is, attention for the source vectors at the {1st, 2nd, ...} positions can be calculated in parallel. This "all at once" bundles (*, target_length) into (source_length, target_length).

For (x_block, x_block), both the "source" and the "target" (of attention) are the source-side sentences of the translation, i.e. self-attention within the source-side sentences.
For (y_in_block, x_block), the "source" (of attention) is the target-side sentences of the translation, and it attends to the source side of the translation, i.e. the typical encoder-decoder attention. (In this case, the terms for translation pairs and those for attention pairs are reversed.)
For self.make_history_mask(y_in_block), this masking is applied to what is intrinsically self-attention on the target side of the translation.

The (i, j) value indicates whether the attention range of the i-th source vector contains the j-th target vector or not.
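
To make this concrete, here is a small sketch with made-up toy data; the two helper functions are my simplified stand-ins for the repository's methods:

import numpy as np

def make_attention_mask(source_block, target_block):
    # True at (b, i, j) iff the i-th "source" position and the j-th "target"
    # position of pair b are both real tokens (padding is -1)
    return (target_block[:, None, :] >= 0) * (source_block[:, :, None] >= 0)

def make_history_mask(block):
    batch, length = block.shape
    tril = np.tril(np.ones((length, length), dtype=bool))  # allow j <= i only
    return np.broadcast_to(tril[None], (batch, length, length))

# Made-up toy batch of 2 translation pairs; -1 marks padding.
x_block    = np.array([[2, 9, 7, -1, -1],
                       [3, 4, 4,  5, -1]])   # source-side sentences
y_in_block = np.array([[1, 6, -1],
                       [1, 8,  2]])          # target-side (decoder input) sentences

xx_mask = make_attention_mask(x_block, x_block)        # (2, 5, 5): self-attention on the source side
xy_mask = make_attention_mask(y_in_block, x_block)     # (2, 3, 5): decoder positions attending to encoder positions
yy_mask = make_attention_mask(y_in_block, y_in_block)  # (2, 3, 3): self-attention on the target side
yy_mask = yy_mask * make_history_mask(y_in_block)      # additionally forbid attending to subsequent positions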


machanic commented on July 18, 2024

Your last sentence: "The (i, j) value indicates whether the attention range of the i-th source vector contains the j-th target vector or not."

>>> source_block
array([[8, 6, 5, 9, 5],
       [9, 9, 5, 5, 8]])
>>> target_block
array([[5, 4, 4, 2, 5, 1, 4, 9],
       [9, 9, 7, 6, 9, 8, 1, 5]])
>>> mask = make_attention_mask(source_block, target_block)
>>> mask.shape
(2, 5, 8)
>>> mask[1]
array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

source_block has 2 rows and target_block has 2 rows (both of shape (batch_size, max_sentence_length)). What is the resulting mask? The shape (2, 5, 8) alone doesn't explain that to me.

    def make_attention_mask(source_block, target_block):
        mask = (target_block[:, None, :] >= 0) * \
            (source_block[:, :, None] >= 0) 
        return mask


soskek commented on July 18, 2024

Remember that I explained this while considering the batch=1 case and ignoring axis 0 for simplicity. The discussion above is for a single sentence pair, where the mask shape is (source_length, target_length). For batch processing, just insert the batch size at the front, i.e. (batchsize, source_length, target_length).

With that, the sentence is extended to the batched case as follows:

The value at index (b, i, j) (of the whole mask array) indicates whether, for the b-th pair in the batch, the attention range of the i-th source vector contains the j-th target vector or not.
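
One way to check this description against the example above (a quick sketch reusing the arrays that were posted; any index triple works):

import numpy as np

def make_attention_mask(source_block, target_block):
    return (target_block[:, None, :] >= 0) * (source_block[:, :, None] >= 0)

source_block = np.array([[8, 6, 5, 9, 5],
                         [9, 9, 5, 5, 8]])
target_block = np.array([[5, 4, 4, 2, 5, 1, 4, 9],
                         [9, 9, 7, 6, 9, 8, 1, 5]])
mask = make_attention_mask(source_block, target_block)   # shape (2, 5, 8)

# mask[b, i, j] is True iff position i of the b-th source sentence and position j
# of the b-th target sentence are both real tokens (id >= 0).
b, i, j = 1, 3, 6
assert bool(mask[b, i, j]) == (source_block[b, i] >= 0 and target_block[b, j] >= 0)

# In this particular example no id is negative (there is no padding), so every
# position may attend to every position and the whole mask is True.
assert mask.all()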


machanic commented on July 18, 2024

@soskek I am sorry, I didn't express my confusion clearly. My confusion is: there is no "i-th source vector". E.g. in my example above there are only 2 source vectors (batch_size = 2), but there is an i-th element of each of the two source vectors. So does the mask value at (b, i, j) indicate that, for the b-th pair in the batch, the i-th element of the source vector "contains" the j-th element of the target vector? That seems strange.

What does your concept of "attention range" mean?
And what do you mean by the "attention range of the i-th source vector"?

I calculated another mask example as follows:

>>> source_block
array([[  2,   9,  -2,  -1,  -2],
       [ -1,   4,   4, -10,  -1]])
>>> target_block
array([[  7, -15, -12],
       [ -4,  -1, -21]])
>>> mask
array([[[ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False]],

       [[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False]]])
>>>
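
Checking this second example against the description above (my own quick sketch, reusing the arrays just posted):

import numpy as np

source_block = np.array([[  2,   9,  -2,  -1,  -2],
                         [ -1,   4,   4, -10,  -1]])
target_block = np.array([[  7, -15, -12],
                         [ -4,  -1, -21]])

print(source_block >= 0)   # pair 0: only positions 0 and 1 are real tokens; pair 1: positions 1 and 2
print(target_block >= 0)   # pair 0: only position 0 is a real token;        pair 1: no real token at all

# mask[b, i, j] is True only when BOTH checks pass, which is why mask[0] is True
# only at (0, 0) and (1, 0), and why mask[1] is entirely False: its target row
# has no non-negative entry for any j.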

