
Comments (7)

janeyx99 commented on June 9, 2024

Pretty sure this is caused by this line: https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L552.

To confirm, @ad8e, you can likely repro this with just:

optim = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True)
...
optim.step()

The real solution is to allow foreach_div to support Scalar as the first argument, but I'm not sure how hard that is, cc @crcrpar. It feels like we should be able to just add an overload. Regarding priority, I'm not sure this is high pri. How likely is this use case? Is there a real use case for having lr be 0?
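
Concretely, the pattern at that line is roughly the following divide-then-reciprocal (a paraphrase, not the exact source): since there is no _foreach_div overload that takes a Scalar as the first argument, step_size = lr / bias_correction1 has to be computed by dividing by lr and taking the reciprocal, which divides by zero when lr is 0.0:

import torch

lr = 0.0
# stand-in for beta1**step - 1, which is negative; the value is illustrative
bias_correction1 = [torch.tensor(-0.1)]

torch._foreach_div_(bias_correction1, lr)     # -0.1 / 0.0 -> -inf
torch._foreach_reciprocal_(bias_correction1)  # 1 / -inf   -> -0.0

# the resulting zero step size later shows up as a divisor in the update,
# which is where the infs/NaNs come from

With a Scalar-as-first-argument overload, lr / bias_correction1 could be computed directly, and 0.0 divided by a finite nonzero value is just 0.0.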

wanchaol commented on June 9, 2024

Is this actually related to DTensor, or is this more about torch.compile + the optimizer? Based on the analysis above, I think that if we just use a normal torch.Tensor with torch.compile and set lr=0.0, we should still repro the issue?

ad8e commented on June 9, 2024

The underlying bug is not in DTensor; it's in the optimizer. DTensor merely exposes this code path in the optimizer.

A normal torch.Tensor with torch.compile and lr=0.0 doesn't hit it; the capturable argument that Jane mentioned is the key.
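
Spelling that out, a minimal sketch of the capturable-only repro (assumes CUDA, since capturable expects params on an accelerator; no DTensor, no torch.compile):

import torch

param = torch.nn.Parameter(torch.ones(8, device="cuda"))
opt = torch.optim.AdamW([param], lr=0.0, capturable=True)

param.grad = torch.randn_like(param)
opt.step()
print(param)  # NaNs show up here on affected builds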

xmfan commented on June 9, 2024

@bdhirsh is this related to the torchtitan NaN loss you were talking about?

xmfan commented on June 9, 2024

@ad8e Does the NaN repro with a single GPU?

ad8e commented on June 9, 2024

DTensor doesn't work when I change the TP mesh size from 2 to 1; I receive:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/clusterstorage/workspace/kevin/nandtensor.py", line 99, in <module>
[rank1]:     model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/parallel/api.py", line 82, in parallelize_module
[rank1]:     random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/random.py", line 345, in _manual_seed
[rank1]:     tensor_parallel_rank = tp_mesh.get_local_rank()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 502, in get_local_rank
[rank1]:     mesh_dim_group = not_none(self.get_group(mesh_dim))
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 411, in get_group
[rank1]:     _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])
[rank1]: IndexError: list index out of range

which means the process group isn't being created when the dim size is 1, so I cannot test whether the NaN would appear with a single GPU.
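
For concreteness, the mesh setup that hits this is roughly the following (an illustrative sketch, not the actual nandtensor.py; run under torchrun with 2 ranks):

from torch.distributed.device_mesh import init_device_mesh

# a 2D mesh whose TP dim has size 1 -- the shape that hits the IndexError above
mesh_2d = init_device_mesh("cuda", (2, 1), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]          # the size-1 dim's process group never gets created
print(tp_mesh.get_local_rank())  # raises IndexError on affected builds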

If I remove the DTensor, like so:

# model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
model_tp = model
...
# gas_loss = gas_loss.full_tensor() # commented out

Then no NaNs appear. So the NaN only appears with DTensor.

It's not high priority for me because DTensor TP is currently useless due to low performance, so I don't use it anywhere. If DTensor actually mattered (above 70B scale, or if it finally gets comm/compute overlap working), lr=0.0 would come up in linear decay/warmup schedules, where it is common at the endpoints but avoidable. Another use case is re-baking the AdamW second moment, which is needed when resuming from a checkpoint saved without optimizer states (useful for saving disk space); that can be done with a very low LR instead of 0.0, as sketched below.
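
That workaround would look roughly like this (a sketch; rebake_batches, compute_loss, and target_lr are illustrative names, and 1e-12 is just an arbitrarily small nonzero LR):

# params restored from a checkpoint, but optimizer state was not saved
opt = torch.optim.AdamW(model.parameters(), lr=1e-12)  # tiny nonzero LR instead of 0.0

for batch in rebake_batches:           # a few batches of ordinary training data
    loss = compute_loss(model, batch)
    loss.backward()
    opt.step()                         # populates exp_avg / exp_avg_sq while params barely move
    opt.zero_grad()

for group in opt.param_groups:
    group["lr"] = target_lr            # hand control back to the real schedule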

If anyone else cared about DTensor, they would be able to spot the NaN issue and work around it in both cases, since it is not a silent failure.

ad8e commented on June 9, 2024

I tried Jane's testcase by taking the original DTensor TP=2 example and making these modifications:

opt = AdamW(...
    capturable=True, # this is new
)
...
# opt.step = torch.compile(opt.step) # this is removed

The NaNs appear, so her diagnosis is correct.
