Comments (4)
- bmt.load() loads the state dict into memory only on rank 0 (about 130B parameters in your case); after partitioning, the shards go directly into each card's GPU parameters, so the total footprint stays at 130B. The bottleneck is therefore the model as it exists before load.
- With ModelCenter, CPM-Live, or other model libraries whose transformer layers are implemented directly on bmtrain, the parameters are already partitioned when model(config) is constructed, so with 8 cards each card holds only 130B/8 and the total is still 130B. The difficulty is that a model built before wrapping is not under bmtrain's control (in your example, the model is built with the transformers library), so before it is wrapped every card holds a full copy of the model.
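To make the arithmetic above concrete, each card under this partitioning holds total/n_gpus parameters. A rough sketch, assuming fp16 weights (2 bytes per parameter, an assumption since precision is not stated above) and ignoring gradients, optimizer state, and activations:

```python
# Rough memory arithmetic for partitioning a 130B-parameter model over 8 GPUs.
# Assumes fp16 (2 bytes/parameter); gradients, optimizer state, and
# activations are deliberately left out of this sketch.
total_params = 130e9
n_gpus = 8

per_gpu_params = total_params / n_gpus        # parameters held per card
per_gpu_gib = per_gpu_params * 2 / 2**30      # fp16 weight bytes per card

print(f"{per_gpu_params:.3e} params/GPU, ~{per_gpu_gib:.1f} GiB of weights")
```
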
- What the wrapper does is actually fairly simple: it replaces the model's nn.Parameter objects with bmt.DistributedParameter and wraps each transformer block in bmt.CheckpointBlock(). A parameter wrapped as bmt.DistributedParameter is partitioned across cards at model(config) time, which reduces total memory usage. So if you want the parameters partitioned already at model(config), the only option is to modify how model(config) builds the model (I haven't thought of a better way yet). One approach: in the BloomModel source, where the transformer blocks are constructed at lines 625-632, change
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
to
self.word_embeddings = bmt.BMTrainModelWrapper(nn.Embedding(config.vocab_size, self.embed_dim))
self.word_embeddings_layernorm = bmt.BMTrainModelWrapper(LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon))
# Transformer blocks
self.h = nn.ModuleList([
    bmt.CheckpointBlock(bmt.BMTrainModelWrapper(BloomBlock(config)))
    for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
self.ln_f = bmt.BMTrainModelWrapper(LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon))
Another example: if you use BloomModelForCausalLM, it adds one more linear layer (line 828) on top of BloomModel; that linear layer also needs to be wrapped with bmt.BMTrainModelWrapper. You can apply the same pattern to the other cases.
What I mean is: could a tool be provided that takes a BloomModel(config) class and converts it into a bmtrain-compatible BMT_BloomModel(config), instead of requiring manual modification — something that processes the model recursively, the way BMTrainModelWrapper does?
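The recursive conversion requested here can be sketched structurally. The snippet below uses minimal stand-in classes rather than torch/bmtrain, and the names `convert`, `Wrapped`, `Block`, and `Leaf` are all hypothetical, not part of any real API; a real tool would walk an nn.Module's `named_children()` and reassign wrapped submodules in the same way.

```python
# Structural sketch of a recursive module converter. Block and Leaf stand in
# for nn.Module containers and leaf layers; convert() mirrors how such a tool
# could walk the child tree and replace leaves with wrapped versions in place.
class Leaf:
    def __init__(self, name):
        self.name = name

class Block:
    def __init__(self, **children):
        self.__dict__.update(children)

class Wrapped:
    """Stand-in for a bmt.BMTrainModelWrapper-style wrapper."""
    def __init__(self, inner):
        self.inner = inner

def convert(module, wrap):
    """Recursively wrap every Leaf found below `module`."""
    for attr, child in vars(module).items():
        if isinstance(child, Leaf):
            setattr(module, attr, wrap(child))   # replace leaf with wrapper
        elif isinstance(child, Block):
            convert(child, wrap)                 # recurse into containers
    return module

model = Block(embed=Leaf("embed"), h=Block(block0=Leaf("block0")))
convert(model, Wrapped)
```

The key design point is that the traversal rewrites the parent's attribute rather than the child itself, which is also why a real implementation must reassign through the parent module.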
Also, if
self.h = nn.ModuleList([
    bmt.CheckpointBlock(bmt.BMTrainModelWrapper(BloomBlock(config)))
    for _ in range(config.num_hidden_layers)
])
is changed to the following form, computation and communication can be overlapped, right?
self.h = bmt.TransformerBlockList([
    bmt.CheckpointBlock(bmt.BMTrainModelWrapper(BloomBlock(config)))
    for _ in range(config.num_hidden_layers)
])
Thanks. Handling it the way you describe doesn't look like much work, so other models can be adapted quickly this way.
Also, in a pretraining scenario where the cluster has many GPUs, would pure ZeRO parallelism run into a communication bottleneck? Wouldn't two-level data parallelism be more reasonable? Suppose there are 1000 GPUs: first split them into 20 groups of 50, with each group processing a different data batch. Within a group, the 50 GPUs do ZeRO parallelism, dividing the group's batch into 50 mini-batches, one per GPU. Once the forward and backward computation finishes, gradients are aggregated within each group first, then across groups to obtain the total gradient; finally, the total gradient is distributed back to every group, and each group updates its parameters in the ZeRO fashion.
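As a sanity check on this scheme: with equal-sized groups, an intra-group average followed by an inter-group average reproduces the flat all-reduce average, so the two-level hierarchy changes the communication structure but not the result. A minimal sketch in plain Python, with scalar "gradients" standing in for per-GPU gradient tensors:

```python
# Two-level gradient averaging: average within each group of GPUs, then
# average the per-group results. With equal group sizes this equals the
# flat average over all workers.
def flat_average(grads):
    return sum(grads) / len(grads)

def two_level_average(grads, group_size):
    groups = [grads[i:i + group_size] for i in range(0, len(grads), group_size)]
    group_means = [sum(g) / len(g) for g in groups]   # intra-group reduce
    return sum(group_means) / len(group_means)        # inter-group reduce

grads = [float(i) for i in range(1000)]  # one "gradient" per GPU, 1000 GPUs
flat = flat_average(grads)
hier = two_level_average(grads, group_size=50)        # 20 groups of 50
```

Note the equality relies on equal group sizes; with unequal groups the inter-group step would need to weight each group mean by its size.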
More parallelism strategies are currently being implemented.
Related Issues (20)
- Installation succeeds but import fails (bmtrain version 0.2.2)
- Model loading
- BMTrain setup without torch
- Adam offloading thread bugs
- bmt.load(model) -> Unexpected OOM
- Make Checkpointing Optional
- Installing BMTrain fails: nccl.obj : error LNK2001: XXXX
- How to distribute weights to different GPUs?
- TypeError: expected string or bytes-like object
- Could bmtrain later be used together with spark-gpu, with Java, Scala, and C++ versions developed?
- Error when pip install bmtrain
- Potential issue in gather_reuslt
- Potential issue in gather result
- Potential issue in gather result
- Support Tensor Parallel
- print_inspect(model, "*") errors when the model contains Linear(config.hidden_size, config.vocab_size, bias=False)
- [BUG] Signal killed caused by Adam Offload
- [Feature] does bmtrain support torch 2.0+
- [BUG] Tensor Parallel async_chunk=4 mismatch async_chunk=1 result when sequence length longer than 16K