Comments (5)
Thanks, this fixes the issue and is better than what I would have PR'd. Sorry that I dumped two bad testcases and then went to sleep.

> @ad8e as for your following offer, it is ok

To clarify, do you mean it is ok to do it, or "it is ok" as in it is not necessary?

> have a great New Year's, Kevin

You too!
> To clarify, do you mean it is ok to do it, or "it is ok" as in it is not necessary?

it isn't necessary, not for this lib
@ad8e hey Kevin, thanks for reporting. i quickly checked on a test script and it seems to be fine:
import torch
from x_transformers import (
    TransformerWrapper,
    Decoder,
    AutoregressiveWrapper
)

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 8,
        depth = 1,
        heads = 4
    )
)

model = AutoregressiveWrapper(model)

# token ids must be integer typed for the embedding lookup
prompts = torch.zeros((1, 1), dtype = torch.long)

# greedy decode without the kv cache
generated = model.generate(
    prompts,
    seq_len = 100,
    temperature = 0.,
    cache_kv = False
)

# greedy decode with the kv cache
kv_cache_generated = model.generate(
    prompts,
    seq_len = 100,
    temperature = 0.,
    cache_kv = True
)

# at temperature 0 both decodes are deterministic and must match exactly
assert torch.allclose(generated, kv_cache_generated)
could you modify the script so that it breaks? perhaps you are using some hyperparameter that is incompatible with the kv cache (if so, it would be good to patch that into the can_cache_kv logic)
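For illustration, a minimal sketch of the kind of guard that suggestion points at; the condition below is an assumption based on the resolution later in this thread, not the library's actual can_cache_kv implementation:

def can_cache_kv(use_abs_pos_emb, decode_len, max_seq_len):
    # assumption: a learned absolute positional embedding indexes a
    # fixed-size table, so cached decoding is only safe while every
    # position stays inside the trained context window
    if use_abs_pos_emb:
        return decode_len <= max_seq_len
    # other positional schemes (e.g. rotary) place no such limit here
    return True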
logits and logits2 in your code have different shapes. you need to compare logits and logits2[:, -1:]
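A small self-contained sketch of that fix; logits and logits2 here are stand-in tensors with the shapes implied by the comment, not the actual model outputs:

import torch

# stand-ins: logits2 is a full forward pass over the whole sequence,
# logits is the single-step output produced while decoding with the kv cache
logits2 = torch.randn(1, 100, 20000)   # (batch, seq_len, num_tokens)
logits = logits2[:, -1:].clone()       # (batch, 1, num_tokens)

# comparing the two directly fails on shape; slice the full forward pass
# down to its last position first
assert logits.shape != logits2.shape
assert torch.allclose(logits, logits2[:, -1:])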
@ad8e as for your following offer, it is ok, as the library is model-architecture specific. what you mention is all training-related
@ad8e ah, got to the bottom of it Kevin

so it turns out the default (absolute positional embedding) is not kv cache friendly once you exceed the maximum sequence length (context window). however, it should still work when decoding from the 1st token up to the max context window size

i added an assert to prevent this, but also defaulted the enwik8 training script to use rotary embeddings, which are the preferred positional embeddings these days (llama uses them) and remain kv cache friendly when exceeding the context length
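A toy illustration (not the library's code) of why a learned absolute positional embedding cannot address positions past its table during cached decoding, whereas rotary embeddings can:

import torch
import torch.nn as nn

max_seq_len, dim = 1024, 8

# a learned absolute positional embedding holds one row per position
abs_pos_emb = nn.Embedding(max_seq_len, dim)

abs_pos_emb(torch.arange(max_seq_len))  # fine: positions 0..1023 all exist

try:
    # cached decoding one step past the context window asks for row 1024
    abs_pos_emb(torch.tensor([max_seq_len]))
except IndexError as err:
    print(f"no embedding row for position {max_seq_len}: {err}")

# rotary embeddings instead rotate queries and keys by angles computed from
# the position index, so any integer position is representable (quality can
# still degrade past the trained length, but decoding does not break)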
ok, back to the holidays; have a great New Year's, Kevin