Comments (7)
Pretty sure this is caused by this line: https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L552.
To confirm, @ad8e you can likely repro this with just:

```python
optim = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True)
...
optim.step()
```
The real solution is to allow foreach_div to support Scalar as the first argument, but I'm not sure how hard that is, cc @crcrpar. It feels like we should be able to just add an overload. Regarding priority, I'm not sure this is high-pri. How likely is this use case? Is there a real use case for having lr be 0?
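A minimal runnable version of that suggested repro, as a hedged sketch (the model and sizes are illustrative, not from the original issue; capturable=True requires CUDA params here):

```python
import torch

# Illustrative tiny model; any module with trainable params should do.
model = torch.nn.Linear(8, 8, device="cuda")
optim = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True)

# Populate gradients so the step actually updates optimizer state.
model(torch.randn(2, 8, device="cuda")).sum().backward()
optim.step()  # under the hypothesis above, this step poisons the params

# Check whether any parameter became NaN after the lr=0.0 capturable step.
print(any(p.isnan().any().item() for p in model.parameters()))
```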
Is this actually related to DTensor, or is this more about torch.compile + optimizer? Based on the analysis above, I think that if we just use a normal torch.Tensor with torch.compile and set lr=0.0, we should still repro the issue?
The underlying bug is not in DTensor; it's in the optimizer. DTensor merely exposes this code path in the optimizer. A normal torch.Tensor with torch.compile and lr=0.0 doesn't hit it; it's the capturable argument that Jane mentioned that is the key.
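A hedged sketch of that distinction (the setup is illustrative, not taken from the original repro):

```python
import torch

model = torch.nn.Linear(8, 8, device="cuda")

# 1) Plain tensor params + torch.compile with lr=0.0: does NOT hit the bad path.
optim = torch.optim.AdamW(model.parameters(), lr=0.0)
optim.step = torch.compile(optim.step)

# 2) capturable=True with lr=0.0: takes the capturable branch Jane pointed at,
#    which is where the NaNs come from under this diagnosis.
optim = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True)
```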
@bdhirsh is this related to the torchtitan NaN loss you were talking about?
@ad8e Does the NaN repro with a single GPU?
DTensor doesn't work when I change the TP mesh size from 2 to 1; I receive:
```
[rank1]: Traceback (most recent call last):
[rank1]:   File "/clusterstorage/workspace/kevin/nandtensor.py", line 99, in <module>
[rank1]:     model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/parallel/api.py", line 82, in parallelize_module
[rank1]:     random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/random.py", line 345, in _manual_seed
[rank1]:     tensor_parallel_rank = tp_mesh.get_local_rank()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 502, in get_local_rank
[rank1]:     mesh_dim_group = not_none(self.get_group(mesh_dim))
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 411, in get_group
[rank1]:     _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])
[rank1]: IndexError: list index out of range
```
which means the process group isn't being created when the mesh dim size is 1. So I cannot test whether the NaN appears with a single GPU.
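For reference, a hypothetical reconstruction of the mesh setup that triggers the IndexError (names and shapes are assumptions; run under torchrun with 2 ranks):

```python
from torch.distributed.device_mesh import init_device_mesh

# Assumed 2D mesh with a tp dimension of size 1, matching the description above.
mesh = init_device_mesh("cuda", (2, 1), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh["tp"]

# parallelize_module(model, tp_mesh, parallelize_plan=layer_plan) then raises the
# IndexError above, because no process group exists for the size-1 mesh dim.
```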
If I remove the DTensor, like so:
```python
# model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
model_tp = model
...
# gas_loss = gas_loss.full_tensor()  # commented out
```
Then no NaNs appear. So the NaN only appears with DTensor.
It's not high priority for me: DTensor TP is currently useless to me due to low performance, so I don't use it anywhere. If DTensor did matter (above 70B scale, or if it finally gets comm/compute overlap working), then an LR of 0 would come up in linear decay/warmup schedules, where lr=0.0 is common at the endpoints but avoidable (see the sketch below). Another use case is re-baking the AdamW second moment, which is necessary for resuming from a checkpoint saved without optimizer states (useful for saving disk space); this can be done with a very low LR instead of 0.0.
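A hedged sketch of how an exact-zero LR arises at a warmup endpoint (the warmup length and model are illustrative):

```python
import torch

model = torch.nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Linear warmup: the multiplicative factor at step 0 is 0, so the very first
# step runs with an effective lr of exactly 0.0.
warmup_steps = 100
sched = torch.optim.lr_scheduler.LambdaLR(
    optim, lambda step: min(1.0, step / warmup_steps)
)
print(sched.get_last_lr())  # [0.0]
```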
If anyone else cared about DTensor, they would be able to spot the NaN issue and work around it in both cases, since it is not a silent failure.
I tried Jane's test case by taking the original DTensor TP=2 example and making these modifications:
```python
opt = AdamW(...
    capturable=True,  # this is new
)
...
# opt.step = torch.compile(opt.step)  # this is removed
```
The NaNs appear. So her diagnosis is correct.
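For completeness, a hedged sketch of the "very low LR instead of 0.0" workaround mentioned above (the floor value and warmup length are illustrative, not from the thread):

```python
import torch

model = torch.nn.Linear(8, 8, device="cuda")
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, capturable=True)

# Clamp the warmup factor to a tiny positive floor so the effective lr is never
# exactly 0.0; the resulting update is effectively a no-op but avoids the bug.
min_factor = 1e-12
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda step: max(min(1.0, step / 100), min_factor)
)
```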