I actually encountered a similar scenario.
The standard Hugging Face bert-base-cased model, trained with 16-bit mixed precision (using PyTorch Lightning) with a vocabulary size of 100K and a sequence length of 1024, uses around 34 GB of GPU memory at batch size 8 on an NVIDIA A100. If I switch to full 32-bit precision, memory usage almost doubles to 67 GB, which is expected.
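To put these numbers in perspective, here is a rough back-of-the-envelope sketch of the parameter and optimizer state. It is my own approximation, not code from Hugging Face or x-transformers. It assumes Adam and a BERT-base layout (token-type embeddings, pooler and MLM head ignored, output projection tied). Under 16-mixed the weights, gradients and both Adam moments all stay in fp32, about 16 bytes per parameter, so the precision setting mainly changes the activation memory.

```python
def bert_like_param_count(vocab: int, seq_len: int, dim: int, depth: int, ff_mult: int = 4) -> int:
    """Approximate parameter count of a BERT-base-like encoder.

    Assumptions: token + learned absolute position embeddings; per layer, Q/K/V/output
    projections, a two-layer feedforward and two LayerNorms, all with biases. Token-type
    embeddings, the pooler and the MLM head are ignored (output projection tied).
    """
    embeddings = vocab * dim + seq_len * dim
    attention = 4 * (dim * dim + dim)                  # Q, K, V and output projections
    hidden = ff_mult * dim
    feedforward = (dim * hidden + hidden) + (hidden * dim + dim)
    layer_norms = 2 * (2 * dim)                        # weight and bias of two LayerNorms
    return embeddings + depth * (attention + feedforward + layer_norms)


def training_state_gib(n_params: int, bytes_per_param: int = 16) -> float:
    """fp32 weights (4) + fp32 grads (4) + Adam first and second moments (4 + 4) bytes per parameter."""
    return n_params * bytes_per_param / 2**30


n = bert_like_param_count(vocab=100_000, seq_len=1024, dim=768, depth=12)
print(n, round(training_state_gib(n), 2))  # 162640896 2.42
```

About 2.4 GiB is a small fraction of the 34 GB observed. Most of the footprint therefore presumably comes from activations, which is the part mixed precision shrinks, and that fits the near-doubling at 32-bit.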
However, if I use a BERT-like model built with x-transformers (mimicking the Hugging Face config):
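from x_transformers import TransformerWrapper, Encoder

# In my setup this is assigned to self.bert inside the LightningModule
bert = TransformerWrapper(
    num_tokens=100_000,
    max_seq_len=1024,
    emb_dropout=0.1,
    tie_embedding=True,      # share input embedding and output projection weights
    use_abs_pos_emb=True,    # learned absolute positions (a TransformerWrapper argument, not an Encoder one)
    attn_layers=Encoder(
        dim=768,
        depth=12,
        heads=12,
        attn_flash=False,
        layer_dropout=0.1,   # stochastic depth: dropout entire layer
        attn_dropout=0.1,    # dropout post-attention
        ff_dropout=0.1,      # feedforward dropout
    ),
)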
Memory usage does not change when I switch between 16-mixed and 32 precision: with the same batch size and hardware it stays at a constant 52 GB, substantially higher than the 34 GB of the Hugging Face model.
Why does the precision setting of the Lightning Trainer not affect the x-transformers implementation?
I would love to use x-transformers because of its many new features, but I am wondering where these significant differences in GPU memory come from. And why does torch.autocast, which I believe Lightning uses under the hood, show no effect?
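One way to check whether autocast takes effect at all is to look at output dtypes and at how many bytes autograd saves for backward. This is only an illustration. It uses a hypothetical feedforward block made of plain torch.nn layers rather than the x-transformers model, and it runs on CPU with bfloat16 so it works without a GPU. It assumes that Lightning's 16-mixed corresponds to torch.autocast on CUDA with float16, as the question suspects. Saved-tensor bytes are counted with torch.autograd.graph.saved_tensors_hooks.

```python
import torch
import torch.nn as nn


def saved_tensor_bytes(model: nn.Module, x: torch.Tensor, use_autocast: bool) -> int:
    """Total bytes of the tensors autograd saves for backward during one forward pass."""
    total = 0

    def pack(t: torch.Tensor) -> torch.Tensor:
        nonlocal total
        total += t.numel() * t.element_size()  # count the bytes, keep the tensor unchanged
        return t

    def unpack(t: torch.Tensor) -> torch.Tensor:
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        # CPU + bfloat16 so the sketch runs anywhere; on an A100 one would use
        # device_type="cuda" and dtype=torch.float16 instead.
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=use_autocast):
            model(x)
    return total


torch.manual_seed(0)
# Hypothetical stand-in for one transformer feedforward block (not x-transformers code).
block = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
x = torch.randn(1024, 768)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out_dtype = block(x).dtype  # bfloat16: linear layers run in lower precision under autocast

fp32_bytes = saved_tensor_bytes(block, x, use_autocast=False)
amp_bytes = saved_tensor_bytes(block, x, use_autocast=True)
print(out_dtype, block[0].weight.dtype, fp32_bytes, amp_bytes)
```

The parameters stay in float32 while the saved activations shrink. The counts also include weight copies that linear layers save, so they are an upper bound on activation memory rather than an exact figure. Running the same measurement on the real model with a batch of token ids would presumably show whether its saved tensors shrink under autocast. Equal byte counts would point to ops that bypass autocast.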
from x-transformers.