
NormXU commented on May 20, 2024

@KerfuffleV2 Actually, I managed to understand how it works.

Here is my understanding:

I think it is the rolling buffer cache that implements this mechanism. The rolling buffer cache maintains layer-wise `cache_k` and `cache_v` tensors whose sequence dimension has size `sliding_window`. When new tokens arrive, they evict (overwrite) the oldest tokens in the buffer.

Here is how the cache_v and cache_k are initialized:

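```python
# Excerpt from the cache constructor in mistral-src: one K and one V buffer
# for every layer, each holding only `sliding_window` positions per sequence.
self.cache_k = torch.empty((
    n_layers,
    max_batch_size,
    sliding_window,
    n_kv_heads,
    head_dim
))
self.cache_v = torch.empty((
    n_layers,
    max_batch_size,
    sliding_window,
    n_kv_heads,
    head_dim
))
```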

Here is how the layer-wise cache view is fetched and passed to each layer for computation:

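```python
# Each layer receives its own view of the shared cache (or None when no cache is used).
for layer_id, layer in enumerate(self.layers):
    cache_view = None if cache is None else cache.get_view(layer_id, input_metadata)
    h = layer(h, freqs_cis, cache_view)
```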
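To make the "evict the oldest tokens" step concrete, here is a minimal pure-Python sketch of a rolling buffer for a single layer. It assumes the placement rule described in the Mistral paper, where the key/value for absolute position i goes to slot i mod W, so once the buffer is full each new token overwrites the oldest one. The class and method names (`RollingBuffer`, `write`, `contents`) are hypothetical and are not mistral-src's API; strings stand in for the K/V tensors.

```python
class RollingBuffer:
    """Hypothetical single-layer K/V buffer with room for `window` positions."""

    def __init__(self, window):
        self.window = window
        self.slots = [None] * window  # slot j holds (position, key, value) or None

    def write(self, pos, key, value):
        # Position `pos` always lands in slot pos % window, overwriting
        # whatever older position previously occupied that slot.
        self.slots[pos % self.window] = (pos, key, value)

    def contents(self):
        # Cached entries ordered by absolute position (oldest first).
        return sorted(entry for entry in self.slots if entry is not None)


buf = RollingBuffer(window=3)
for pos, tok in enumerate("ABCDE"):
    buf.write(pos, key=f"k_{tok}", value=f"v_{tok}")

print([pos for pos, _, _ in buf.contents()])  # [2, 3, 4]: positions 0 and 1 were overwritten
```

The buffer never grows past W entries, which is why the cache tensors above only need a `sliding_window`-sized sequence axis.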

As for my question:

> Do the remaining layers only attend to tokens 11 to 15?

Well, this was a stupid question, since the sliding window doesn't work that way. I had misunderstood the mechanism and thought there would be an attention mask giving each layer a different window of W tokens. But the truth is as @timlacroix said:

> The "sliding" mechanism of the sliding window has to be understood over the sequence length. Each layer attends to the previous W tokens at that depth.

The sliding window mechanism restricts each token to attending only to other tokens within a fixed-size window W. However, how far information propagates through the network does not depend on the window size alone; it also depends on stacking multiple attention layers, which gives tokens indirect access to positions outside their own window.
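A minimal sketch of the per-layer mask this describes, assuming causal attention (each token looks only at itself and earlier tokens) and a window of W tokens that includes the token itself. `sliding_window_mask` is a hypothetical helper in plain Python, not code from mistral-src, whose own mask construction may count the window boundary differently.

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True when query position i may attend to key position j."""
    return [
        # Causal (j <= i) and within the last `window` positions (j > i - window).
        [i - window < j <= i for j in range(seq_len)]
        for i in range(seq_len)
    ]


for row in sliding_window_mask(seq_len=6, window=3):
    print("".join("x" if allowed else "." for allowed in row))
# Expected output:
# x.....
# xx....
# xxx...
# .xxx..
# ..xxx.
# ...xxx
```

Every layer applies the same band; nothing in the mask itself changes with depth.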

For example, say we have a sequence of tokens [A, B, C, D, E, F, G, H] and a sliding window (W) 3 tokens wide.

The output of Layer 1:

- Token $\hat{A}$ integrates information from [A, B, C].
- Token $\hat{B}$ integrates information from [A, B, C, D].
- Token $\hat{C}$ integrates information from [A, B, C, D, E].

Layer 2:

When token $\hat{A}$ in the second layer attends to token $\hat{B}$, it also indirectly gets information about token D, and when it attends to token $\hat{C}$, it gets information about tokens D and E.

This means token A in layer 2 has a "reach" that extends to token E, even though it can directly attend only to [A, B, C].

from mistral-src.

defdet commented on May 20, 2024

@NormXU Hey, sorry for a bit of necroposting. I'm also trying to understand this caching idea, and I think you're right about the general idea but wrong about the details. If we look at the illustration from the Mistral paper:

[Image: sliding window attention illustration from the Mistral paper]
We can see here that the token at position i attends to at most W tokens of the previous layer, ending at position i (each token only sees the previous layer's tokens to its left, up to and including itself). At the start (i = 0), it attends only to the token at position 0 of the previous layer. If we follow your example, it would mean:

The output of Layer 1 (A', B', ... are the corresponding hidden states):

- A' integrates information only from A.
- B' integrates information from A, B.
- C' integrates information from A, B, C.
- D' integrates information from B, C, D (we reach the window size of 3), and so on.

Because decoder layers are stacked, tokens can still reach tokens outside the window indirectly.
I might be completely wrong, but that's how it looks from the Mistral paper.
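The stacking argument in this corrected example can be checked by brute force. The sketch below uses a hypothetical helper, `influencing_inputs`, in plain Python; it tracks which input tokens can influence each position's hidden state after a given number of layers, assuming the causal window just described (position i attends to positions i - W + 1 through i of the previous layer).

```python
def influencing_inputs(seq_len, window, n_layers):
    """For each position, the set of input positions that can reach it after n_layers."""
    # Before any layer, each position only "contains" its own token.
    reach = [{i} for i in range(seq_len)]
    for _ in range(n_layers):
        reach = [
            # Union over the previous-layer positions this token attends to.
            set().union(*(reach[j] for j in range(max(0, i - window + 1), i + 1)))
            for i in range(seq_len)
        ]
    return reach


tokens = "ABCDEFGH"
for layers in (1, 2, 3):
    r = influencing_inputs(len(tokens), window=3, n_layers=layers)
    print(layers, "".join(tokens[j] for j in sorted(r[-1])))
# Expected output:
# 1 FGH
# 2 DEFGH
# 3 BCDEFGH
```

Under this convention each extra layer pushes the reach of H back by W - 1 = 2 positions. Counting conventions differ: if the window instead covers positions i - W through i, each layer extends the reach by exactly W.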


MrYxJ commented on May 20, 2024

> @defdet Firstly, I want you to know that I really appreciate your super quick and detailed reply!!!
>
> Could you clarify what the query and keys would be in this example? I was under the impression that, for layer 2, Q would be tokens 5-9, and the Key would be tokens 1-5 (AKA layer 2 attends to tokens in layer 1).

Hey bro, I've read the Mistral implementation on GitHub, so let me try to help you understand the process. First, you have to understand that this model (Mistral) has multiple layers (32 in total) and what each layer's cache is for: in Mistral, every layer has a cache that stores the keys and values computed from that layer's input (i.e., the previous layer's output) for the tokens of the previous chunk.

With that in mind, consider the example in the paper's illustration:
[Image: sliding window attention illustration from the Mistral paper]

What the second layer in the figure actually expresses is this: when the input is the second chunk, tokens 5-9, the queries come only from the current input tokens 5-9, while the keys and values cover tokens 1-9, because the cached tokens 1-5 of the previous chunk are spliced in front of the current ones. With a sliding window of, say, 4, each token can see at most 4 tokens back. So within this layer's attention computation, token 9 of the chunk 5-9 can see back as far as token 5. In this way, token 9 in layer 2 indirectly sees the information of token 1 from layer 1, by way of token 5. In general, with a sliding window of size W, the token at position T in one layer can see back as far as position T-W in the previous layer.
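The query/key split described above can be sketched as follows. This is a simplified, hypothetical illustration in plain Python, not mistral-src's code: keys are the cached positions spliced with the current chunk, queries come only from the chunk, and each query keeps the keys in its window. It takes the numbers literally from the comment (tokens 1-5 cached, tokens 5-9 as the current chunk, so the overlapping token 5 is counted once) and assumes W = 4 with a window covering positions T - W through T.

```python
def chunk_attention_keys(cached_positions, chunk_positions, window):
    """Map each query position in the current chunk to the key positions it attends to."""
    # Keys/values: cached positions spliced with the current chunk (overlap counted once).
    keys = sorted(set(cached_positions) | set(chunk_positions))
    # Queries come only from the current chunk; each keeps keys in [t - window, t].
    return {t: [k for k in keys if t - window <= k <= t] for t in chunk_positions}


visible = chunk_attention_keys(cached_positions=range(1, 6), chunk_positions=range(5, 10), window=4)
for query, keys in visible.items():
    print(query, keys)
# Expected output:
# 5 [1, 2, 3, 4, 5]
# 6 [2, 3, 4, 5, 6]
# 7 [3, 4, 5, 6, 7]
# 8 [4, 5, 6, 7, 8]
# 9 [5, 6, 7, 8, 9]
```

Token 9 sees back only to token 5, while token 5 still sees token 1, which is the indirect path the comment describes.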


timlacroix commented on May 20, 2024

The "sliding" mechanism of the sliding window has to be understood over the sequence length.
Each layer attends to the previous W tokens at that depth.

In your example, at any layer, when decoding token number 17, you'll attend to tokens 13 to 17.
When decoding the next token, number 18, you'll attend to tokens 14 to 18.

No idea if that clears anything up?


NormXU commented on May 20, 2024

@timlacroix Thank you for your reply.
According to the README,

> At each attention layer, information can move forward by W tokens at most: after two attention layers, information can move forward by 2W tokens, etc. For instance in a sequence of length 16K and a sliding window of 4K, after 4 layers, information has propagated to the full sequence length.

This is where I'm really confused. If my understanding is correct, the attention mechanism here works "layer-wise", which differs from the usual one we're familiar with, since a 16K-token sequence can be processed through just 4 layers with a 4K sliding window.
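The README's numbers follow from a simple upper bound: if each layer lets information move forward by at most W tokens, then after L layers it has moved at most L × W tokens, so covering a sequence of length N needs at least ceil(N / W) layers. A minimal sketch of that arithmetic; the helper name is hypothetical, and 16K and 4K are taken as 16 × 1024 and 4 × 1024.

```python
import math


def layers_to_cover(seq_len, window):
    """Smallest layer count L with L * window >= seq_len (the 'W tokens per layer' bound)."""
    return math.ceil(seq_len / window)


seq_len, window = 16 * 1024, 4 * 1024
print(layers_to_cover(seq_len, window))  # 4
```

This bounds how far information can travel; it does not mean the sequence is processed in 4-layer slices. Every layer still runs over the whole sequence, with each token attending through its own window.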


KerfuffleV2 commented on May 20, 2024

@NormXU Sorry to bug you, but did you actually manage to figure out the answer to this or did you just eventually give up on ever receiving an answer?


KerfuffleV2 commented on May 20, 2024

@NormXU Thank you so much for the in-depth explanation; it's very much appreciated! I know llama.cpp has been looking to implement this, but it didn't seem like the necessary information was available.

I'll share this in the discussion over there, and I'm sure it will help a lot! Edit: This is the discussion, in case you want to take a look: ggerganov/llama.cpp#3581


NormXU commented on May 20, 2024

@defdet Thank you for pointing this out. I think you are right. Each token considers information only from the tokens preceding it, rather than using bidirectional attention. This makes sense to me :)


rshah918 commented on May 20, 2024

Sorry for reviving this thread again haha.

I'm slightly confused about the inputs to each layer. Let's say the context length is 10, and each multi-head attention layer has an input size of 5. Does layer 1 read tokens 1-5 and layer 2 take in tokens 6-10? And does each token in layer 2 attend to the tokens in layer 1, as per the diagram in the Mistral paper? @NormXU @defdet


defdet commented on May 20, 2024

@rshah918, you're right about the overall idea, but I think you're a bit off on the details. The attention windows should overlap, as shown in the diagram above. In your example (counting the first token as 1, not 0), the first layer does read tokens 1-5, but the second one reads tokens 5-9, not 6-10. This detail is really important, because token 5 carries information about tokens 1-4. If we don't include it in the calculations, the upper layers (2 and above) will have no information about tokens 1-4. Also, I'm pretty sure we're talking about whole decoder layers (each consisting of self-attention and an MLP), not only attention layers.
Come to think of it, the attention windows overlap by only one token. Maybe it would be beneficial to make them overlap by 2 tokens?


rshah918 commented on May 20, 2024

@defdet Firstly, I want you to know that I really appreciate your super quick and detailed reply!!!

Could you clarify what the query and keys would be in this example? I was under the impression that, for layer 2, Q would be tokens 5-9, and the Key would be tokens 1-5 (AKA layer 2 attends to tokens in layer 1).


matrixssy commented on May 20, 2024


Actually, I have also looked at their source code, and I'm curious why SWA can't be applied during training.


defdet commented on May 20, 2024

It doesn't matter anyway lol. Mistral v2 doesn't have SWA.
