Comments (13)
@KerfuffleV2 Actually, I managed to understand how it works.
Here is my understanding:
I think it is the rolling buffer cache that implements this mechanism. The rolling buffer cache maintains layer-wise cache_k and cache_v buffers whose size equals the sliding window. When new tokens arrive, they expel the oldest tokens from the buffer.
Here is how the cache_v and cache_k are initialized:
Lines 122 to 135 in 147c4e6
Here is how we get layer-wise cache and send it to each layer for computing
Lines 226 to 228 in 147c4e6
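To make the "expel previous tokens" behavior concrete, here is a minimal sketch of the rolling (ring) buffer idea, assuming a per-layer cache of fixed size W indexed by absolute position modulo W. The class and method names here are hypothetical, not the actual mistral-src API.

```python
# Hypothetical sketch of a rolling buffer KV cache; not the mistral-src API.
class RollingKVCache:
    def __init__(self, window):
        self.window = window
        self.cache_k = [None] * window  # one slot per cached position
        self.cache_v = [None] * window

    def update(self, pos, k, v):
        # The token at absolute position `pos` overwrites slot pos % W,
        # expelling whatever token sat there W positions earlier.
        slot = pos % self.window
        self.cache_k[slot] = k
        self.cache_v[slot] = v

    def positions_held(self, cur_pos):
        # Absolute positions currently stored: at most the last `window` tokens.
        start = max(0, cur_pos - self.window + 1)
        return list(range(start, cur_pos + 1))

cache = RollingKVCache(window=4)
for pos in range(6):  # stream the keys/values of tokens 0..5 into the cache
    cache.update(pos, k=f"k{pos}", v=f"v{pos}")

print(cache.positions_held(5))  # [2, 3, 4, 5] -- tokens 0 and 1 were expelled
print(cache.cache_k)            # ['k4', 'k5', 'k2', 'k3'] (ring layout)
```

Note the ring layout: slots are never shifted, only overwritten, which is why the cache stays O(W) regardless of sequence length.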
As for my question,
Do the remaining layers only run attention over tokens 11 to 15?
Well, this was a naive question, since the sliding window doesn't work that way. I misunderstood the sliding window mechanism before and thought there should be an attention mask such that each layer had a different W-sized window. But the truth is as @timlacroix said:
The "sliding" mechanism of the sliding window has to be understood over the sequence length. Each layer attends to the previous W tokens at that depth.
The sliding window mechanism restricts each token to attending only to other tokens within a fixed-size window W. However, the propagation of information through the network does not rely solely on the size of the attention window; it also comes from stacking multiple attention layers, which works more like indirect access.
For example, suppose we have a sequence of tokens [A, B, C, D, E, F, G, H], and the sliding window (W) is 3 tokens wide. Under the windowing used in this example, the output of Layer 1 is:
Token A integrates information from [A, B, C]
Token B integrates information from [B, C, D]
Token C integrates information from [C, D, E]
In Layer 2:
When token A attends to the Layer-1 outputs of [A, B, C], the output for token C already carries information from [C, D, E].
This means token A in layer 2 has a "reach" that extends to token E, even though it can directly attend to only [A, B, C].
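The "reach" computation can be written out explicitly. This sketch follows the window convention used in this comment, where each token's window covers itself and the next W-1 tokens (the follow-up replies point out the real model's window is causal and looks left instead).

```python
# Window convention from this comment: token i covers [i, i+W-1].
tokens = list("ABCDEFGH")
W = 3

def window(i):
    # indices covered by token i's attention window in this framing
    return range(i, min(i + W, len(tokens)))

# Layer 1: each output integrates the tokens inside its own window.
layer1 = [set(tokens[j] for j in window(i)) for i in range(len(tokens))]
# Layer 2: attending to Layer-1 outputs unions their coverage.
layer2 = [set().union(*(layer1[j] for j in window(i))) for i in range(len(tokens))]

print(sorted(layer2[0]))  # ['A', 'B', 'C', 'D', 'E'] -- A's reach extends to E
```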
from mistral-src.
@NormXU Hey, sorry for a bit of necroposting. I'm also trying to understand this caching idea and I think you're right about the general idea but wrong about the details. If we look at the illustration from Mistral paper:
We can see here that the token at position i attends to at most tokens i-w:i from the previous layer (each token only sees the previous layer's tokens on its left side). At minimum (if i=0), it's going to attend to only one token, at position i=0 of the previous layer. If we follow your example, it would mean:
The output of Layer 1 (A', B'.... are the corresponding hidden states):
A' integrates information from only A.
B' integrates from A, B.
C' integrates from A, B, C
D' integrates from B, C, D (we reach the window size of 3) and so on.
And because of the stacked nature of decoder layers, tokens can attend to other tokens outside window size.
I might absolutely be wrong, but that's how it looks from the Mistral paper.
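The causal receptive-field growth described above can be checked with a small sketch: each token attends to at most itself plus the previous W-1 tokens, and stacking decoder layers widens the effective reach leftward.

```python
# Causal window: token i attends to positions [max(0, i-W+1), i].
tokens = list("ABCDEFGH")
W = 3

def window(i):
    return range(max(0, i - W + 1), i + 1)

# coverage[i] = which original tokens the hidden state at position i
# has integrated so far (layer 0: just the token itself).
coverage = [{tokens[i]} for i in range(len(tokens))]
for _ in range(2):  # apply two stacked attention layers
    coverage = [set().union(*(coverage[j] for j in window(i)))
                for i in range(len(tokens))]

print(sorted(coverage[3]))  # D after two layers: ['A', 'B', 'C', 'D']
print(sorted(coverage[7]))  # H after two layers: ['D', 'E', 'F', 'G', 'H']
```

After one layer, D's state integrates only [B, C, D] (window of 3); after the second layer it reaches A, i.e. outside the window, exactly because of the stacking.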
@defdet Firstly, I want you to know that I really appreciate your super quick and detailed reply!!!
Could you clarify what the query and keys would be in this example? I was under the impression that, for layer 2, Q would be tokens 5-9, and the Key would be tokens 1-5 (AKA layer 2 attends to tokens in layer 1).
Hey, bro, I have looked at the Mistral implementation on GitHub, and I'll try to help you understand the process. First, you have to understand that this model (Mistral) has multiple layers (32, in fact), and the role of each layer's cache: in Mistral, each layer has a cache that stores that layer's input values from the previous chunk (which are actually the output values of the previous layer).
After understanding what I said above, consider the example in the paper's illustration:
What the second layer in the figure actually shows is that when the input is the second chunk, tokens 5-9, the queries come from the current input tokens 5-9, while the keys and values become tokens 1-9 by splicing in the cached tokens 1-5 from the previous chunk. Because of the sliding window size (4, for example), each token can only see the previous 4 tokens at most. So in this layer's attention computation, token 9 of the chunk 5-9 can see back to token 5 at the farthest. In this way, layer-2 token 9 indirectly sees layer-1 token 1's information through layer-2 token 5. In general, if the sliding window size is W, the token at position T in this layer can see back to the token at position T-W in the previous layer.
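The chunked example above can be sketched with a hypothetical helper (not the mistral-src API): queries come from the current chunk, keys/values are the spliced cache plus the current chunk, and the window caps how far back a query can look.

```python
# Hypothetical sketch of chunked prefill with cache splicing.
W = 4  # sliding window size, as in the example

def attended_positions(q_pos, kv_positions, window):
    # a query may attend to itself and at most the previous `window` tokens
    return [p for p in kv_positions if q_pos - window <= p <= q_pos]

cached = [1, 2, 3, 4, 5]   # this layer's inputs stored from the first chunk
current = [5, 6, 7, 8, 9]  # the second chunk of queries
kv = sorted(set(cached + current))  # spliced keys/values: tokens 1..9

print(attended_positions(9, kv, W))  # [5, 6, 7, 8, 9] -- token 9 sees back to 5
print(attended_positions(5, kv, W))  # [1, 2, 3, 4, 5] -- token 5 carries 1..4
```

This shows the indirect path: token 9 never sees token 1 directly, but it attends to token 5, whose previous-layer state already integrated tokens 1-4.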
The "sliding" mechanism of the sliding window has to be understood over the sequence length.
Each layer attends to the previous W tokens at that depth.
In your example, at any layer, when decoding token number 17, you'll attend to tokens 13 to 17.
When decoding the next token, number 18, you'll attend to tokens 14 to 18.
No idea if that clears anything up?
@timlacroix Thank you for your reply.
According to the README,
At each attention layer, information can move forward by W tokens at most: after two attention layers, information can move forward by 2W tokens, etc. For instance in a sequence of length 16K and a sliding window of 4K, after 4 layers, information has propagated to the full sequence length.
This is where I am really confused. If my understanding is correct, the attention here is a "layer-wise" mechanism that differs from the standard attention we are familiar with, since information in a 16K-long sequence can propagate across the full length after only 4 layers with a 4K sliding window.
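The README's propagation claim is just arithmetic: each layer lets information move forward by at most W positions, so after k layers it has moved by about k * W. For a 16K sequence and a 4K window:

```python
# Receptive-field arithmetic for stacked sliding-window layers.
W = 4096          # sliding window size
seq_len = 16 * 1024  # 16K sequence

layers_needed = -(-seq_len // W)  # ceiling of seq_len / W
print(layers_needed)  # 4 -- after 4 layers information spans the full sequence
```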
@NormXU Sorry to bug you, but did you actually manage to figure out the answer to this or did you just eventually give up on ever receiving an answer?
@NormXU Thank you so much for the in-depth explanation, it is very much appreciated! I know llama.cpp has been looking to implement this, but it didn't seem like the necessary information was available.
I'll share this in the discussion over there and I'm sure it will help a lot! Edit: This is the discussion in case you wanted to take a look: ggerganov/llama.cpp#3581
@defdet Thank you for pointing this out. I think you are right. All tokens only consider information from their preceding tokens, rather than having bi-directional attention. This makes sense to me :)
Sorry for reviving this thread again haha.
I'm slightly confused on the inputs to each layer. Let's say your context vector length is 10, and each multi-head-attention layer has an input size of 5. Does layer 1 read tokens 1-5, and layer 2 takes in tokens 6-10? and each token in layer 2 attends to the tokens in layer 1, as per the diagram in the Mistral paper? @NormXU @defdet
@rshah918, you are correct about the whole idea, but I think you're a bit wrong about the details. Attention windows should be overlapping, as shown in the diagram above. In your example it would mean (the first token is number 1, not 0) that the first layer indeed reads tokens 1-5, but the second one reads tokens 5-9, not 6-10. This detail is actually really important, because token 5 has information about tokens 1-4. If we don't include it in the calculations, upper layers (2 and above) are going to have no information about tokens 1-4. Also, I'm pretty sure we're talking about whole decoder layers (consisting of self-attention and an MLP each), not only attention layers.
Come to think of it, attention windows are overlapping by only one token. Maybe it could be beneficial to make them overlap by 2 tokens?..
Actually, I have also looked at their released source code, and I'm curious why SWA cannot be applied to training.
It doesn't matter anyway lol. Mistral v2 doesn't have SWA.