sangmichaelxie / doremi

PyTorch implementation of DoReMi, a method for optimizing the data mixture weights in language modeling datasets

Home Page: https://arxiv.org/abs/2305.10429

License: MIT License

Python 1.79% Shell 0.03% Makefile 0.01% CMake 0.32% Cuda 10.61% HTML 72.08% CSS 0.04% JavaScript 0.08% C++ 15.02% C 0.02% Batchfile 0.01% Dockerfile 0.01%
data-centric-machine-learning large-language-models nlp

doremi's People

Contributors

eltociear, sangmichaelxie

doremi's Issues

Speed decrease during training

We set up the environment and preprocessed the data following the provided instructions. However, while running bash scripts/runs/run_pile_baseline120M.sh, we noticed a sudden slowdown after certain batches, for example around batch 500 of 200,000. The speed dropped from 2 iterations per second to 6 seconds per iteration, a more than 10-fold slowdown. The same issue occurs during the second stage when running bash scripts/runs/run_pile_doremi120M.sh. Our setup consists of 4 A100 nodes with 80 GB of memory each. Have you encountered this problem in your own training, or do you have any insight into what might cause it? Thanks!

Domain weights are mostly near one-hot

Hi Michael,

Thanks for the amazing work and for releasing the code base! I am running DoReMi in my own setting and noticed that, most of the time, the domain weights are nearly one-hot. As a result, most proxy model updates are dominated by the domain with the largest excess loss. Is this expected? In this case, the proxy model reduces the overall average loss much more slowly than training with the original domain proportions.

Thanks!
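For reference, the domain-weight update in the paper's Algorithm 1 multiplies each weight by exp(η·λ_i), renormalizes, and then mixes in a small fraction c of the uniform distribution. The pure-Python sketch below (the function name, parameter names, and all numbers are illustrative assumptions, not the repo's code) shows how a domain whose excess loss stays largest step after step can absorb almost all of the weight, and how a smaller step size keeps the weights far from one-hot.

```python
import math


def update_domain_weights(prev_weights, excess_losses, step_size=1.0, smoothing=1e-3):
    """One exponentiated-gradient step on the domain weights, then smoothing
    toward uniform (hypothetical helper sketching Algorithm 1's update).

    prev_weights must be strictly positive and sum to 1.
    """
    k = len(prev_weights)
    # Multiplicative update done in log space: log(alpha_i) + eta * lambda_i
    logits = [math.log(w) + step_size * lam for w, lam in zip(prev_weights, excess_losses)]
    # Renormalize with a max-subtracted softmax for numerical stability
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    updated = [e / total for e in exps]
    # Mix in c * uniform so that no domain weight collapses to exactly 0
    return [(1 - smoothing) * w + smoothing / k for w in updated]


# Hypothetical excess losses, held fixed: domain 0 is always the "worst" domain
excess = [2.0, 0.5, 0.1]
aggressive = gentle = [1 / 3] * 3
for _ in range(20):
    aggressive = update_domain_weights(aggressive, excess, step_size=1.0)
    gentle = update_domain_weights(gentle, excess, step_size=0.01)

print([round(w, 4) for w in aggressive])  # nearly one-hot on domain 0
print([round(w, 4) for w in gentle])      # still far from one-hot
```

In real training the excess losses change as the proxy learns, so this fixed-λ loop exaggerates the effect; it only illustrates the direction of the dynamics and the role of the step size and smoothing constant.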

CUDA version problem

I'm trying to run this script on AWS SageMaker:

cd doremi && bash scripts/setup_flash.sh

It throws an error:

The detected CUDA version (12.1) mismatches the version that was used to compile
      PyTorch (11.7). Please make sure to use the same CUDA versions.

About the loss

Help, please:
1. Why does the excess loss not follow the paper's max(excess loss, 0)?
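In the paper, the per-domain excess loss clips each token's difference at zero, max(ℓ_proxy − ℓ_ref, 0), sums it over the domain's tokens in the batch, and divides by that domain's token count. The sketch below (a hypothetical helper with toy numbers, in plain Python rather than the repo's PyTorch code) contrasts the clipped and unclipped versions, which differ whenever some tokens have a proxy loss below the reference loss.

```python
def per_domain_excess_loss(proxy_losses, ref_losses, domain_ids, num_domains, clip=True):
    """Per-domain excess loss, averaged over each domain's tokens.

    proxy_losses / ref_losses: one list of per-token losses per example.
    domain_ids: the domain index of each example.
    clip=True applies max(proxy - ref, 0) per token, as in the paper.
    """
    sums = [0.0] * num_domains
    counts = [0] * num_domains
    for seq_proxy, seq_ref, d in zip(proxy_losses, ref_losses, domain_ids):
        for lp, lr in zip(seq_proxy, seq_ref):
            diff = lp - lr
            sums[d] += max(diff, 0.0) if clip else diff
            counts[d] += 1
    # Domains with no tokens in this batch get 0
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


# Toy batch: one example from each of two domains
proxy = [[2.0, 1.0], [0.5, 0.5]]
ref = [[1.0, 1.5], [1.0, 0.2]]
clipped = per_domain_excess_loss(proxy, ref, [0, 1], num_domains=2)
unclipped = per_domain_excess_loss(proxy, ref, [0, 1], num_domains=2, clip=False)
print(clipped, unclipped)
```

Without the clip, a domain where the proxy already beats the reference can get a negative excess loss, which shrinks its weight under the exponentiated update instead of leaving it unchanged.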

Questions about the loss used for optimizing the proxy model

@sangmichaelxie It seems that the loss used to optimize the proxy model in the code differs from the one described in the paper:

loss = (pertoken_loss * curr_domain_weights.detach()).sum() / normalizer

In the code, you directly use the proxy model's own loss for the optimization, but in the paper the loss appears to be the minimax loss, which uses the excess loss. Which one should I follow? Or is there something wrong with my understanding? Thanks.

[image]
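One way to see how the two relate: the reference loss does not depend on the proxy parameters θ, so wherever every domain's excess loss is positive, the gradient of the weighted excess loss equals the gradient of the weighted proxy loss; they differ only where the clip max(·, 0) is active. The one-parameter check below (the quadratic per-domain losses and all numbers are assumptions for illustration) uses central finite differences to show both cases. It does not describe how the repo itself handles the clip.

```python
def proxy_loss(theta, target):
    # Toy per-domain proxy loss: a quadratic in a single parameter theta
    return (theta - target) ** 2


def weighted_proxy(theta, alphas, targets):
    # sum_i alpha_i * loss_i(theta): the form of the code's proxy loss
    return sum(a * proxy_loss(theta, t) for a, t in zip(alphas, targets))


def weighted_clipped_excess(theta, alphas, targets, ref_losses):
    # sum_i alpha_i * max(loss_i(theta) - ref_i, 0): the paper's objective for fixed weights
    return sum(a * max(proxy_loss(theta, t) - r, 0.0)
               for a, t, r in zip(alphas, targets, ref_losses))


def numerical_grad(f, theta, eps=1e-6):
    # Central finite difference
    return (f(theta + eps) - f(theta - eps)) / (2 * eps)


alphas, targets = [0.3, 0.7], [1.0, -2.0]
theta = 0.0  # proxy losses here are 1.0 and 4.0

g_proxy = numerical_grad(lambda th: weighted_proxy(th, alphas, targets), theta)
# Reference losses below the proxy losses: both excess losses are positive
g_excess = numerical_grad(
    lambda th: weighted_clipped_excess(th, alphas, targets, [0.5, 1.0]), theta)
# Reference loss 5.0 > 4.0 for domain 1: its clipped term is 0, so the gradients differ
g_clipped = numerical_grad(
    lambda th: weighted_clipped_excess(th, alphas, targets, [0.5, 5.0]), theta)
print(g_proxy, g_excess, g_clipped)
```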

Question about updating the domain weights only on process 0

Hi Michael,

Thanks for releasing this code base and for all the amazing work you have done! I'm learning about DoReMi and have a question: I noticed that the domain weights are updated only on process 0, so how do the other processes get the new weights when they compute the loss and update the proxy model?

Thanks!

Cannot reproduce the results shown in the GitHub repo with the 120M reference model on A800 (8 × 80 GB)

Hi, thanks for sharing this code base.

After running bash scripts/run_pile.sh, I obtained the following results:
[image]

The generated domain weights differ slightly from the released domain weights:
[image]

Since I am in mainland China, I downloaded the tokenizer manually. I could not find togethercomputer/RedPajama-INCITE-Base-7B-v0.1 on Hugging Face, so I used the tokenizer togethercomputer/RedPajama-INCITE-Base-7B instead. I think they are the same.

Edge Case Discussion

Thanks for the wonderful work!
I have a question about Equation 1 in the paper. If the proxy model's parameters become (or are initialized to be) the same as the reference model's, training would converge and the learned domain weights might not be meaningful (since they no longer matter once the loss is 0). Could you clarify, @sangmichaelxie?
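Assuming the updates in the paper's Algorithm 1 (per-domain excess loss λ_t clipped at 0, an exponentiated update with step size η, then smoothing with constant c toward the uniform distribution u), this edge case can be worked through directly. If the proxy's per-token losses equal the reference's, every λ_t[i] is 0, the exponentiated step is the identity (the weights already sum to 1), and only the smoothing remains:

```latex
\lambda_t[i] = \frac{1}{\sum_{x \in B_i} |x|} \sum_{x \in B_i} \sum_{j=1}^{|x|}
  \max\{\ell_{\theta,j}(x) - \ell_{\mathrm{ref},j}(x),\, 0\} = 0
  \quad \text{when } \ell_{\theta,j} = \ell_{\mathrm{ref},j}

\alpha'_t = \frac{\alpha_{t-1} \odot \exp(\eta \lambda_t)}
                 {\sum_{i} \alpha_{t-1}[i] \exp(\eta \lambda_t[i])} = \alpha_{t-1},
\qquad
\alpha_t = (1 - c)\,\alpha'_t + c\,u = (1 - c)\,\alpha_{t-1} + c\,u

\alpha_0 = u \;\Rightarrow\; \alpha_t = u \quad \text{for all } t
```

Under these assumptions the weights stay at the uniform initialization for as long as the proxy's per-token losses match the reference's exactly.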

Request for RedPajama Dataset Weights

Thank you for your excellent work! I noticed that the repository includes a baseline build for the RedPajama dataset, but I couldn't find domain weights derived from RedPajama. Could you please consider publishing some of the weights obtained on this dataset?

Question about the FlashAttention version

Hi, thanks for sharing this code base.

Which version of FlashAttention is used in this repo? The latest version of FlashAttention is incompatible with transformers==4.27.2.

How many rounds do we need to converge domain weights on The Pile?

Thanks for your awesome work! I noticed that there is an optimized weight file, configs/pile_doremi_r1_120M_ref:pile_baseline_50kvocab_nopack_120M.json, as shown in the README. Can we consider these domain weights the result of the first round of DoReMi?

Comparing them with the results in the paper, these optimized weights are far from the ones reported there. For example, the domain weight of Pile-CC is 0.13788709, but the paper reports 0.6057. If 0.13788709 is the result of the first round, the Pile-CC domain weight increased by about 0.028861896 in that round. We then estimate that it would take approximately 21 rounds to converge to 0.6057.

P.S. Thanks for your reply in issue #11. I would also like to ask how many rounds are needed for the domain weights to converge on RedPajama?

Thanks.
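The estimate above can be checked with a small calculation, assuming (as the issue does) that every round changes the Pile-CC weight by the same amount, which is a strong assumption. The value ≈21 matches 0.6057 / 0.028861896, a count that implicitly starts from a weight of 0; starting instead from the implied initial weight (0.13788709 - 0.028861896) gives about 17.2 rounds in total, i.e. 18 whole rounds. The numbers come from the issue; the variable names are illustrative.

```python
import math

round1_weight = 0.13788709        # Pile-CC weight in the released round-1 config (from the issue)
per_round_increase = 0.028861896  # increase over the initial weight (from the issue)
paper_weight = 0.6057             # Pile-CC weight reported in the paper (from the issue)

initial_weight = round1_weight - per_round_increase

# Dividing the target by the per-round increase assumes the weight starts at 0
naive_rounds = paper_weight / per_round_increase
# Linear extrapolation from the initial weight instead
linear_rounds = (paper_weight - initial_weight) / per_round_increase

print(f"initial weight: {initial_weight:.4f}")
print(f"naive estimate: {naive_rounds:.2f} rounds")
print(f"linear estimate: {linear_rounds:.2f} rounds ({math.ceil(linear_rounds)} whole rounds)")
```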

Step 1: baseline_280M loss is large

The 280M baseline model's loss hovers around 5, with all training hyperparameters set to their default values.
The preprocessing sampler is set to 100k ("10w").
[image]

[image]

Question about optimized weights in the paper

Hi! I tried to train the main model directly with the optimized weights from the paper (pile_doremi_280M_256kvocab_paper.json) and got significantly worse results than the baseline model (for example, at the 20k checkpoint this model reaches only 4.8 acc, while the baseline reaches 5.6). Is this because these weights were found with the 280M proxy model? Shouldn't they generalize across models?

BTW, I just noticed that pile_doremi_280M_256kvocab_paper.json and pile_doremi_r1_120M_ref:pile_baseline_50kvocab_nopack_120M.json show completely different trends in their weights. Does this suggest that there may be many optimal weight combinations for the whole dataset? Thank you so much.

Easy HF dataset for DoReMi?

Is there a Hugging Face-compatible dataset I could use? For example, instead of

dataset = load_dataset("c4", "en", streaming=True, split="train").with_format("torch")
remove_columns = ["text", "timestamp", "url"]

I would like to have

dataset = load_dataset("doremi", "en", streaming=True, split="train").with_format("torch")
remove_columns = ["text", "timestamp", "url"]

so that the DoReMi weights are applied automatically. Is that possible?
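The request above uses "doremi" as a dataset name; that name is the asker's hypothetical example, not a dataset documented on this page. DoReMi's output is a sampling distribution over domains, which can be applied to any set of per-domain streams. A minimal pure-Python sketch of that idea follows (mix_domains, the domain names, and the weights are all hypothetical).

```python
import random
from itertools import islice
from typing import Any, Dict, Iterator, Tuple


def mix_domains(streams: Dict[str, Iterator[Any]], weights: Dict[str, float],
                seed: int = 0) -> Iterator[Tuple[str, Any]]:
    """Yield (domain, example) pairs, picking each domain with probability
    proportional to its weight. Stops as soon as the chosen domain runs out."""
    rng = random.Random(seed)
    names = list(streams)
    probs = [weights[name] for name in names]
    while True:
        name = rng.choices(names, weights=probs, k=1)[0]
        try:
            yield name, next(streams[name])
        except StopIteration:
            return


# Stand-ins for per-domain streams (e.g. one streaming dataset per domain)
streams = {"web": iter(range(10_000)), "code": iter(range(10_000))}
weights = {"web": 0.8, "code": 0.2}  # hypothetical weights, not DoReMi's
first_1000 = list(islice(mix_domains(streams, weights), 1000))
frac_web = sum(1 for name, _ in first_1000 if name == "web") / len(first_1000)
print(f"fraction of 'web' examples: {frac_web:.3f}")
```

With the datasets library, a rough equivalent would be to load one streaming dataset per domain and pass them to interleave_datasets with probabilities set to the domain weights; check the documentation of the installed version for the available stopping strategies.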

ModuleNotFoundError: No module named 'flash_attn.models.falcon'

I ran bash scripts/setup_flash.sh without errors (although it took only a few minutes):

[image]

But I got an error when running bash scripts/run_pile.sh:

Traceback (most recent call last):
  File "doremi/train.py", line 56, in <module>
    import doremi.models as doremi_models
  File "/storage/home/lanzhenzhongLab/zhaoyu/doremi/doremi/models.py", line 8, in <module>
    from flash_attn.models.gpt import GPTLMHeadModel as GPTLMHeadModelFlash
  File "/storage/home/lanzhenzhongLab/zhaoyu/.conda/envs/zy_doremi/lib/python3.8/site-packages/flash_attn-2.0.4-py3.8-linux-x86_64.egg/flash_attn/models/gpt.py", line 31, in <module>
    from flash_attn.models.falcon import remap_state_dict_hf_falcon
ModuleNotFoundError: No module named 'flash_attn.models.falcon'

What is wrong here?

Also, bash scripts/run_preprocess_pile.sh did not work until I updated the datasets package to 2.15.0, whereas setup.py specifies 2.10.1.
Am I doing something wrong?

Question about 8B model architecture

Thanks for your interesting work!

Could you please provide the configuration of the 8B model?

How did you create the proxy model? Is it derived from the 8B model, or is it just a smaller model with an arbitrary configuration (e.g., number of layers)?

Multi-node support

Hi,

Thanks for sharing this open-source implementation. Does the current implementation support training a larger reference/proxy model across multiple nodes?

Thanks

Question about model initialization

Do the reference model, proxy model, and main model have to be initialized with the same method? When continuing to pretrain Llama 2 with DoReMi, the main model's weights are initialized from the Meta checkpoint, but no such checkpoints exist for the reference and proxy models. Instead, these models are initialized with other methods (e.g., Xavier initialization). In this scenario, will the domain weights from the proxy model still improve the performance of the main model?

Is the loss computation wrong?

It seems that the loss implementation (https://github.com/sangmichaelxie/doremi/blob/main/doremi/trainer.py#L360) is not exactly the same as the loss in the paper. In the implementation, the normalizer is Σ_i α_i Σ_{x∈D_i} |x|, but for samples from the i-th domain it should be just Σ_{x∈D_i} |x|. Any comments on this observation?

Here is code that implements the loss as in the paper. It seems to produce smoother domain weights:

import torch
import torch.distributed as dist

# compute the rescaled loss: start from the current domain weights
train_domain_weights = self.read_weights().to(pertoken_loss.device)
# if doing non-uniform sampling, normalize by inverse sampling weight
train_domain_weights = train_domain_weights / self.sampling_weights.to(train_domain_weights.device)
train_domain_weights = train_domain_weights / train_domain_weights.sum()

# (#domains,) total number of tokens amongst samples from each domain
perdomain_num_tokens = []
for domain_id in range(len(train_domain_weights)):
    domain_mask = (inputs['domain_ids'] == domain_id)
    if domain_mask.sum() > 0:
        num_tokens = token_mask[domain_mask].sum()
    else:
        num_tokens = torch.tensor(0., device=token_mask.device)
    perdomain_num_tokens.append(num_tokens)
perdomain_num_tokens = torch.stack(perdomain_num_tokens)

# sync `perdomain_num_tokens` across processes, since different processes
# may process micro-batch samples from the same domain
dist.all_reduce(perdomain_num_tokens, op=torch.distributed.ReduceOp.SUM)
# divide by world size because DDP averages gradients across processes
perdomain_num_tokens = perdomain_num_tokens / self.args.world_size

# avoid division by zero
perdomain_num_tokens[torch.where(perdomain_num_tokens==0)] = 1.
# (#domains,) equivalent to α_i / Σ_{x∈D_i} |x|
perdomain_coeff = train_domain_weights/perdomain_num_tokens
# (bsz, seq_len-1)
coeff = perdomain_coeff[inputs['domain_ids']].unsqueeze(-1) * token_mask
loss = (pertoken_loss * coeff.detach()).sum()
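To make the difference concrete, the sketch below computes both normalizations described in this issue on a toy batch (plain Python, single process, hypothetical helper names, no sampling-weight correction). With S_i the summed token loss and n_i the token count of domain i, the implementation's loss as described in the issue is Σ_i α_i S_i / Σ_i α_i n_i, and the paper's is Σ_i α_i S_i / n_i. With weights summing to 1, the two agree when every domain contributes the same number of tokens and diverge otherwise.

```python
def domain_totals(token_losses, domain_ids, num_domains):
    """Summed token loss S_i and token count n_i for each domain."""
    sums = [0.0] * num_domains
    counts = [0] * num_domains
    for losses, d in zip(token_losses, domain_ids):
        sums[d] += sum(losses)
        counts[d] += len(losses)
    return sums, counts


def implementation_style_loss(token_losses, domain_ids, alphas):
    # sum_i alpha_i * S_i / sum_i alpha_i * n_i
    sums, counts = domain_totals(token_losses, domain_ids, len(alphas))
    num = sum(a * s for a, s in zip(alphas, sums))
    den = sum(a * n for a, n in zip(alphas, counts))
    return num / den


def paper_style_loss(token_losses, domain_ids, alphas):
    # sum_i alpha_i * S_i / n_i, skipping domains absent from the batch
    sums, counts = domain_totals(token_losses, domain_ids, len(alphas))
    return sum(a * s / n for a, s, n in zip(alphas, sums, counts) if n > 0)


alphas = [0.5, 0.5]
ids = [0, 1]
# Domain 0: 4 tokens with loss 1.0; domain 1: 2 tokens with loss 3.0
uneven = [[1.0] * 4, [3.0] * 2]
# Same per-token losses, but domain 1 now also has 4 tokens
even = [[1.0] * 4, [3.0] * 4]
print(implementation_style_loss(uneven, ids, alphas), paper_style_loss(uneven, ids, alphas))
print(implementation_style_loss(even, ids, alphas), paper_style_loss(even, ids, alphas))
```

In the paper-style version, a domain with few tokens in the batch gets a larger per-token coefficient α_i / n_i, which is what the perdomain_coeff line in the code above computes.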

List of pinned requirements / Dockerfile?

I'm struggling to replicate DoReMi training (strange errors, probably due to incompatibilities between dependencies).
Is there a list of pinned requirements (in particular, which version of flash_attn should be used)? Or, preferably, a Dockerfile for easy reproduction?
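Until an official pin list exists, one quick way to see whether an environment matches the versions mentioned in these issues is to compare the installed package versions against them. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); the expected versions are only those stated on this page (datasets 2.10.1 from setup.py, transformers==4.27.2 from the FlashAttention issue), and flash_attn is left out because no version for it is stated here.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions mentioned in the issues on this page; extend this dict as needed
EXPECTED = {"datasets": "2.10.1", "transformers": "4.27.2"}


def check_versions(expected):
    """Map each package to (expected, installed or None, matches?)."""
    report = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None
        report[pkg] = (want, have, have == want)
    return report


for pkg, (want, have, ok) in check_versions(EXPECTED).items():
    status = "OK" if ok else "MISMATCH"
    print(f"{pkg}: expected {want}, installed {have} [{status}]")
```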

Question about the domain weight initialization in Figure 8 of the paper

Thanks for the amazing paper and open codebase! I have a question about Figure 8 in the paper.
The initial domain weights in Figure 8 (step 0) do not seem to match Algorithm 1, which initializes the domain weights as α_0 = 1/k.
The initial domain weights in Figure 8a also differ from those in Figure 8b.
So what is the domain weight initialization strategy in Figure 8?

[image]
[image]

Adding a license

Please add a license to encourage open-source collaboration and broaden impact. An MIT license is suggested.

Questions about directly applying the weights from the paper or the repo to train the main model

First, thanks for your solid work!
I am wondering whether the optimized domain weights depend strongly only on the tokenizer.
Suppose I use the same tokenizer and the domain weights released in the repo to train a main model, but change other training settings, such as the number of training steps, the learning rate, and the global batch size.
Would this work, or must the training procedure be exactly the same as for the proxy and reference models?
@sangmichaelxie

AssertionError: assert q.dtype in [torch.float16, torch.bfloat16]

After finishing training the DoReMi reference model, I wanted to evaluate it on downstream tasks, but I got this error:
Traceback (most recent call last):
  File "/home/wth/My_codes/doremi/doremi/train.py", line 409, in <module>
    main()
  File "/home/wth/My_codes/doremi/doremi/train.py", line 395, in main
    downstream_metrics = trainer.evaluate_fewshot(
  File "/home/wth/My_codes/doremi/doremi/trainer.py", line 670, in evaluate_fewshot
    gen_tokens = model.generate(
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/utils/generation.py", line 166, in generate
    output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/utils/generation.py", line 115, in decode
    logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/My_codes/doremi/doremi/models.py", line 99, in forward
    fwd_output = self._forward(
  File "/home/wth/My_codes/doremi/doremi/models.py", line 59, in _forward
    hidden_states = self.transformer(input_ids, position_ids=position_ids,
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/models/gpt.py", line 373, in forward
    hidden_states, residual = layer(hidden_states, residual,
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/modules/block.py", line 148, in forward
    hidden_states = self.mixer(hidden_states, **mixer_kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 517, in forward
    context = self.inner_cross_attn(q, kv, causal=causal)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wth/anaconda3/envs/doremi/lib/python3.10/site-packages/flash_attn-2.0.4-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 124, in forward
    assert q.dtype in [torch.float16, torch.bfloat16]
AssertionError
May I ask how this problem can be solved?
