
Achazwl commented on September 18, 2024
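  1. bmt.load() loads the state dict into memory only on rank 0 (130B in your case). It then partitions the state dict and places each shard directly into the GPU parameters of the corresponding card, so only 130B of memory is used in total. The bottleneck is therefore the model that exists before load() is called.
  2. With libraries such as ModelCenter and CPM-Live, which implement their transformer layers directly on top of bmtrain, parameters are already partitioned when model(config) is constructed. With 8 GPUs, each card holds only 130B/8, or 130B in total. The difficulty is that a model built before wrapping is not under bmtrain's control (in your example it is built by the transformers library), so until it is wrapped, every card holds a complete copy of the model.
  3. What the wrapper does is fairly simple: it replaces each nn.Parameter in the model with a bmt.DistributedParameter and wraps the transformer blocks with bmt.CheckpointBlock(). Parameters wrapped as bmt.DistributedParameter are partitioned across cards as soon as model(config) runs, which reduces total memory use. So if you want the parameters to be partitioned at model(config) time, the only option is to change how model(config) builds the model (I haven't thought of a better way yet). One approach: in the BloomModel source, lines 625-632, where the transformer blocks are built, change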
```python
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

# Transformer blocks
self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])

# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
```

to:

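```python
# assumes `import bmtrain as bmt` at the top of the modeling file
# BMTrainModelWrapper converts each module's nn.Parameter into bmt.DistributedParameter
self.word_embeddings = bmt.BMTrainModelWrapper(nn.Embedding(config.vocab_size, self.embed_dim))
self.word_embeddings_layernorm = bmt.BMTrainModelWrapper(LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon))

# Transformer blocks: shard each block's parameters, then wrap the block in bmt.CheckpointBlock
self.h = nn.ModuleList([
    bmt.CheckpointBlock(bmt.BMTrainModelWrapper(BloomBlock(config)))
    for _ in range(config.num_hidden_layers)
])

# Final Layer Norm
self.ln_f = bmt.BMTrainModelWrapper(LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon))
```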

Similarly, if you use BloomForCausalLM, it adds a linear layer on top of BloomModel (line 828); that layer also needs to be wrapped with bmt.BMTrainModelWrapper. Apply the same pattern to any other layers you add.
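The memory arithmetic in points 1 and 2 can be sketched in plain Python. This is a toy model, not BMTrain code: `shard_bounds`, `scatter_from_rank0`, and `aggregate_memory` are hypothetical helpers, a Python list stands in for a flattened parameter tensor, and the even, padded split across ranks is an assumption about how the partitioning works.

```python
from __future__ import annotations


def shard_bounds(numel: int, world_size: int, rank: int) -> tuple[int, int]:
    """Half-open [start, end) slice of a flattened parameter owned by `rank`.

    Each rank gets ceil(numel / world_size) elements; the last ranks may get
    fewer (or none) when numel is not divisible by world_size.
    """
    per_rank = -(-numel // world_size)  # ceiling division
    start = min(rank * per_rank, numel)
    end = min(start + per_rank, numel)
    return start, end


def scatter_from_rank0(full: list[float], world_size: int) -> list[list[float]]:
    """Rank 0 holds the full tensor once; every rank keeps only its own slice."""
    return [full[slice(*shard_bounds(len(full), world_size, r))]
            for r in range(world_size)]


def aggregate_memory(model_size: float, world_size: int, sharded: bool) -> float:
    """Memory summed over all ranks: one copy if sharded, world_size copies if replicated."""
    return model_size if sharded else model_size * world_size


print(scatter_from_rank0([float(i) for i in range(10)], world_size=4))
# [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0]]
print(aggregate_memory(130, 8, sharded=True))   # 130: partitioned, as after bmt.load() or with ModelCenter
print(aggregate_memory(130, 8, sharded=False))  # 1040: every card holds a full copy before wrapping
```

The gap between the two totals is exactly the problem described in point 2: a model built outside bmtrain is replicated on every card until it is wrapped.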


jinmin527 commented on September 18, 2024

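What I meant is: could you provide a tool that takes a class like BloomModel(config) and converts it into a bmtrain-compatible BMT_BloomModel(config), so nobody has to edit the model by hand? Something that processes the model recursively, the way BMTrainModelWrapper does.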
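A recursive converter of the kind asked for here could look roughly like the toy below. It uses neither BMTrain nor PyTorch: `Module`, `ShardedParam`, `CheckpointWrapped`, and `convert` are hypothetical stand-ins for nn.Module, bmt.DistributedParameter, bmt.CheckpointBlock, and a BMTrainModelWrapper-style pass, and treating every child of a block list as a transformer block is an assumption. As point 2 of the reply above notes, a pass like this runs after the full model already exists on every card, so it saves the manual edits but not the peak memory of the unwrapped model.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Module:
    """Hypothetical stand-in for torch.nn.Module: named params plus named children."""
    name: str
    params: dict[str, object] = field(default_factory=dict)
    children: dict[str, Module] = field(default_factory=dict)
    is_block_list: bool = False  # marks a container of transformer blocks (like nn.ModuleList)


@dataclass
class ShardedParam:
    """Hypothetical stand-in for bmt.DistributedParameter."""
    value: object


@dataclass
class CheckpointWrapped(Module):
    """Hypothetical stand-in for bmt.CheckpointBlock wrapping one block."""


def convert(module: Module) -> Module:
    """Recursively mimic a BMTrainModelWrapper-style pass over a module tree."""
    # 1. Replace this module's own plain parameters with sharded ones.
    for key, value in list(module.params.items()):
        if not isinstance(value, ShardedParam):
            module.params[key] = ShardedParam(value)
    # 2. Recurse into children; wrap each child of a block list exactly once.
    for key, child in list(module.children.items()):
        child = convert(child)
        if module.is_block_list and not isinstance(child, CheckpointWrapped):
            child = CheckpointWrapped(name=f"Checkpoint({child.name})", children={"inner": child})
        module.children[key] = child
    return module


# Usage: a tiny BloomModel-shaped tree.
model = Module("BloomModel", children={
    "word_embeddings": Module("Embedding", params={"weight": [0.0, 0.1]}),
    "h": Module("ModuleList", is_block_list=True, children={
        str(i): Module("BloomBlock", params={"w": [float(i)]}) for i in range(2)
    }),
})
convert(model)
print(model.children["h"].children["0"].name)  # Checkpoint(BloomBlock)
```

The pass is idempotent: running it twice neither double-shards a parameter nor double-wraps a block.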

Also, changing

```python
self.h = nn.ModuleList([
    bmt.CheckpointBlock(bmt.BMTrainModelWrapper(BloomBlock(config)))
    for _ in range(config.num_hidden_layers)
])
```

to the form below should let computation and communication overlap, right?

```python
self.h = bmt.TransformerBlockList([
    bmt.CheckpointBlock(bmt.BMTrainModelWrapper(BloomBlock(config)))
    for _ in range(config.num_hidden_layers)
])
```
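The overlap being asked about can be illustrated with a toy schedule. The sketch assumes (this is not stated in the thread) that the speedup comes from gathering block i+1's parameters while block i computes, and that a gather and a compute each take one time unit; `sequential_schedule`, `overlapped_schedule`, and `respects_dependencies` are hypothetical names, not BMTrain API.

```python
from __future__ import annotations


def sequential_schedule(num_blocks: int) -> list[list[str]]:
    """No overlap: each time step either gathers one block's params or computes one block."""
    steps: list[list[str]] = []
    for i in range(num_blocks):
        steps.append([f"gather{i}"])
        steps.append([f"compute{i}"])
    return steps


def overlapped_schedule(num_blocks: int) -> list[list[str]]:
    """Prefetch: gather block i+1's params in the same step that computes block i."""
    steps: list[list[str]] = [["gather0"]] if num_blocks else []
    for i in range(num_blocks):
        step = [f"compute{i}"]
        if i + 1 < num_blocks:
            step.append(f"gather{i + 1}")
        steps.append(step)
    return steps


def respects_dependencies(steps: list[list[str]]) -> bool:
    """A block may only compute after its parameters were gathered in an earlier step."""
    finished: set[str] = set()
    for step in steps:
        for op in step:
            if op.startswith("compute") and "gather" + op[len("compute"):] not in finished:
                return False
        finished.update(step)
    return True


for n in (1, 3, 24):
    seq, ovl = sequential_schedule(n), overlapped_schedule(n)
    print(n, len(seq), len(ovl), respects_dependencies(ovl))
# 1 2 2 True / 3 6 4 True / 24 48 25 True
```

Under these assumptions the overlapped schedule needs about half the steps for many blocks; in practice the saving depends on how long each all-gather takes relative to each block's compute.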


jinmin527 commented on September 18, 2024

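Thanks. Done the way you describe, it doesn't seem like much work, and other models could be adapted quickly the same way.
Also, for pretraining on a cluster with many GPUs, would pure ZeRO parallelism run into a communication bottleneck? Would two-level data parallelism be more reasonable? Suppose there are 1000 GPUs. First split them into 20 groups of 50 GPUs, with each group processing a different data batch. Within each group, the 50 GPUs run ZeRO parallelism: the group's batch is split into 50 mini-batches, and each GPU processes its own mini-batch. Once the forward and backward passes finish, gradients are first aggregated within each group and then across groups to obtain the total gradient; the total gradient is then sent back to every group, and each group updates its parameters ZeRO-style.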
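The proposed two-level layout can be sketched as below. This is one possible arrangement, assumed rather than specified in the thread: ranks are grouped contiguously, local rank i in each group is assumed to own ZeRO shard i, `hierarchical_groups` and `hierarchical_mean` are hypothetical helpers, and plain floats stand in for gradient tensors.

```python
from __future__ import annotations


def hierarchical_groups(world_size: int, group_size: int) -> tuple[list[list[int]], list[list[int]]]:
    """Split ranks into intra-group sets (ZeRO inside each) and inter-group sets.

    Intra-group sets are contiguous blocks of `group_size` ranks. Each inter-group
    set holds the ranks with the same local index in every group, i.e. (under the
    assumption above) the ranks that own the same shard and average it across groups.
    """
    if world_size % group_size:
        raise ValueError("world_size must be divisible by group_size")
    num_groups = world_size // group_size
    intra = [list(range(g * group_size, (g + 1) * group_size)) for g in range(num_groups)]
    inter = [[g * group_size + i for g in range(num_groups)] for i in range(group_size)]
    return intra, inter


def hierarchical_mean(grads: list[float], group_size: int) -> float:
    """Average per-rank gradients inside each group first, then across groups."""
    if len(grads) % group_size:
        raise ValueError("len(grads) must be divisible by group_size")
    group_means = [sum(grads[s:s + group_size]) / group_size
                   for s in range(0, len(grads), group_size)]
    return sum(group_means) / len(group_means)


intra, inter = hierarchical_groups(1000, 50)
print(len(intra), len(intra[0]))  # 20 50
print(len(inter), inter[0][:3])   # 50 [0, 50, 100]
```

Averaging within groups and then across groups matches a global average only when all groups are the same size, which the divisibility checks enforce. PyTorch FSDP's `HYBRID_SHARD` strategy follows a similar pattern, sharding within a node and replicating across nodes.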


Achazwl commented on September 18, 2024

More parallelism strategies are being implemented.
