Comments (14)
@abdulfatir Since this might be an NCCL regression, I wonder if we can try the following:
- Get the NCCL versions for both envs, e.g. via
python -c "import torch;print(torch.cuda.nccl.version())"
- Run a script that only does all-reduce
Example with profiler:
import torch
import torch.distributed as dist

# Init process group etc.
t = torch.empty((7347200,), device="cuda")

def fn():
    dist.all_reduce(t)

benchmark_with_profiler(fn)  # maybe change `active` to 3+ instead of 2 to get back-to-back all-reduces in the profile
Example with CUDA events for timing without profiler:
import torch
import torch.distributed as dist

# Init process group etc.
t = torch.empty((7347200,), device="cuda")

def fn():
    dist.all_reduce(t)

num_warmup, num_iters = 3, 10
# Warm up so one-time setup cost is not included in the timing.
for _ in range(num_warmup):
    fn()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(num_iters):
    fn()
end.record()
# Wait for the GPU to finish before reading the event timings.
torch.cuda.synchronize()
elapsed_time = start.elapsed_time(end)  # milliseconds
time_per_all_reduce = elapsed_time / num_iters
print(f"time per all-reduce: {time_per_all_reduce:.5f}")
The latter might be simpler since then we do not need to manually inspect the profiles.
from pytorch.
I think something like the following may work:
import torch
import torch.distributed as dist

def benchmark_with_profiler(
    benchmark_fn,
    *benchmark_fn_args,
    **benchmark_fn_kwargs,
) -> None:
    torch._C._profiler._set_cuda_sync_enabled_val(False)
    wait, warmup, active = 1, 1, 2
    num_steps = wait + warmup + active
    rank = dist.get_rank()
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=wait, warmup=warmup, active=active, repeat=1, skip_first=1
        ),
        # Only rank 0 writes a trace file.
        on_trace_ready=(
            torch.profiler.tensorboard_trace_handler("./") if not rank else None
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
        with_modules=False,
    ) as prof:
        for step_idx in range(1, num_steps + 1):
            benchmark_fn(*benchmark_fn_args, **benchmark_fn_kwargs)
            if rank is None or rank == 0:
                prof.step()

def train_step(ddp_model, optimizer, loss_fn):
    # `device_id` must already be defined by the surrounding script.
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(16, 10))
    labels = torch.randn(16, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()

# Define `ddp_model`, `optimizer`, `loss_fn`
benchmark_with_profiler(train_step, ddp_model, optimizer, loss_fn)
This should just take the profiler trace on rank 0 and save it to some .json file. If you could share the .json file, then that would be great. (We would then view them in something like chrome://tracing/ to compare the difference between the envs.)
from pytorch.
Here are the files:
torch2.1.json
torch2.3.json
For completeness, here is the script I ran using torchrun --nproc-per-node=4 profile-ddp.py:
profile-ddp.py.txt
from pytorch.
Thanks!
I wonder if there is a regression in NCCL across the two versions.
- torch2.1: all-reduces take 1.309 ms and 7.944 ms
- torch2.3: all-reduces take 1.952 ms and 13.987 ms
- All-reduce message sizes are (1054725 * 4) bytes and (7347200 * 4) bytes, respectively
For torch2.1, this equates to roughly
- 2 * (1054725 * 4 bytes) / (0.001309 seconds) / 1e9 = 6.446 GB/s bandwidth
- 2 * (7347200 * 4 bytes) / (0.007944 seconds) / 1e9 = 7.399 GB/s bandwidth
where the factor of 2 is from all-reduce requiring two passes around the ring. More precisely, there could be an (N-1)/N factor for N=4 GPUs, but it does not matter that much. The point is that the achieved bandwidth on these collectives is pretty low.
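The arithmetic above can be reproduced with a short script. This is a minimal sketch: `allreduce_bandwidth_gb_per_s` is a hypothetical helper name, the 4 bytes per element comes from the message sizes quoted above, and the `num_gpus` option applies the 2 * (N-1)/N ring factor, which, as far as I know, is also how nccl-tests derives its "bus bandwidth" figure for all-reduce.

```python
def allreduce_bandwidth_gb_per_s(num_elements, elapsed_ms, num_gpus=None, bytes_per_element=4):
    """Estimate all-reduce bandwidth in GB/s.

    With num_gpus=None this uses the plain factor of 2 from the estimate above;
    otherwise it uses the more precise ring factor 2 * (N - 1) / N.
    """
    message_bytes = num_elements * bytes_per_element
    factor = 2.0 if num_gpus is None else 2.0 * (num_gpus - 1) / num_gpus
    return factor * message_bytes / (elapsed_ms / 1e3) / 1e9


# Timings (ms) reported in this thread for the two all-reduce sizes.
timings_ms = {
    "torch2.1": {1054725: 1.309, 7347200: 7.944},
    "torch2.3": {1054725: 1.952, 7347200: 13.987},
}

for version, by_size in timings_ms.items():
    for num_elements, ms in by_size.items():
        approx = allreduce_bandwidth_gb_per_s(num_elements, ms)
        ring = allreduce_bandwidth_gb_per_s(num_elements, ms, num_gpus=4)
        print(
            f"{version}: {num_elements} elements -> "
            f"{approx:.3f} GB/s (factor 2), {ring:.3f} GB/s (N=4)"
        )
```

Comparing either column across the two versions shows the same relative drop, since the factor only rescales the numbers.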
I am not as familiar with this part, but I think that it is possible to use https://github.com/NVIDIA/nccl-tests to test what kind of all-reduce bandwidth you should expect to see for your hardware setup.
In short, the current evidence points toward an NCCL regression in the all-reduce times for your model/setup, and your training is communication bound. This translates to a slowdown in end-to-end training time.
cc: @kwen2501 @wconstab if you guys have any suggestions on how to further diagnose this
from pytorch.
torch 2.1
NCCL version: (2, 18, 1)
time per all-reduce rank=1: 0.00031
time per all-reduce rank=3: 0.00020
time per all-reduce rank=2: 0.00041
time per all-reduce rank=0: 8.05868
torch 2.2
NCCL version: (2, 19, 3)
time per all-reduce rank=2: 0.03922
time per all-reduce rank=1: 0.03787
time per all-reduce rank=3: 0.26604
time per all-reduce rank=0: 13.19803
torch 2.3
NCCL version: (2, 20, 5)
time per all-reduce rank=3: 0.00057
time per all-reduce rank=1: 0.00051
time per all-reduce rank=2: 0.00051
time per all-reduce rank=0: 13.15645
Code:
import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device_id = rank % torch.cuda.device_count()
    t = torch.empty((7347200,), device=device_id)

    def fn():
        dist.all_reduce(t)

    num_warmup, num_iters = 3, 10
    for _ in range(num_warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    elapsed_time = start.elapsed_time(end)
    time_per_all_reduce = elapsed_time / num_iters
    print(f"time per all-reduce {rank=}: {time_per_all_reduce:.5f}")
    dist.destroy_process_group()
from pytorch.
Regarding the micro benchmark, since you are calling torch.cuda.Event() and torch.cuda.synchronize(), the program must know which device to perform these operations on. Otherwise, these two operations would be performed on device 0 (the default one in CUDA's view). The non-zero ranks have no CUDA kernels on device 0, so they record almost zero time.
You just need to add a line here:
dist.init_process_group("nccl")
rank = dist.get_rank()
device_id = rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)  # <-- add this line
t = torch.empty((7347200,), device=device_id)
Then you would see roughly equal time:
NCCL version 2.19.3+cuda12.0
time per all-reduce rank=1: 5.80684
time per all-reduce rank=3: 5.81415
time per all-reduce rank=0: 5.83221
time per all-reduce rank=2: 5.83355
(I measured the time on a different hardware platform, so don't take the absolute numbers seriously.)
from pytorch.
Sure, here's the result:
torch 2.1
time per all-reduce rank=3: 7.89238
time per all-reduce rank=2: 7.87067
time per all-reduce rank=0: 7.88849
time per all-reduce rank=1: 7.89137
torch 2.3
time per all-reduce rank=3: 13.15246
time per all-reduce rank=2: 13.10884
time per all-reduce rank=0: 13.14847
time per all-reduce rank=1: 13.13393
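To put a number on the gap, the per-rank timings above can be averaged and compared. This is a minimal sketch: `mean_ms` is a hypothetical helper, and the values are treated as milliseconds on the assumption that they come from `torch.cuda.Event.elapsed_time` in the benchmark script.

```python
import re
from statistics import mean


def mean_ms(log: str) -> float:
    """Average the per-rank timings found in the benchmark output lines."""
    values = re.findall(r"time per all-reduce rank=\d+: ([0-9.]+)", log)
    return mean(float(v) for v in values)


# Output copied from the 4 x A10G runs above.
torch21_log = """
time per all-reduce rank=3: 7.89238
time per all-reduce rank=2: 7.87067
time per all-reduce rank=0: 7.88849
time per all-reduce rank=1: 7.89137
"""
torch23_log = """
time per all-reduce rank=3: 13.15246
time per all-reduce rank=2: 13.10884
time per all-reduce rank=0: 13.14847
time per all-reduce rank=1: 13.13393
"""

slowdown = mean_ms(torch23_log) / mean_ms(torch21_log)
print(f"torch 2.3 vs torch 2.1 all-reduce time: {slowdown:.2f}x")
```

On these numbers the ratio comes out at roughly 1.67x, in line with the 7.944 ms versus 13.987 ms all-reduces seen in the profiler traces.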
from pytorch.
Let me type the commands here:
git clone --depth 1 --branch <tag> git@github.com:NVIDIA/nccl.git
git clone git@github.com:NVIDIA/nccl-tests.git
cd nccl && make -j src.build NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80" && cd ..
cd nccl-tests
make -j NCCL_HOME=../nccl/build
LD_LIBRARY_PATH=../nccl/build/lib:$LD_LIBRARY_PATH NCCL_DEBUG=WARN ./build/all_reduce_perf -t 4 -b 29388800 -e 29388800
You can do the above separately with v2.18.5-1 and v2.20.5-1 for the <tag> field.
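The `-b` and `-e` values in the command above are the size in bytes of the larger all-reduce from the profile (7347200 elements at 4 bytes each). A minimal sketch of that conversion, in case the smaller message is worth testing too; `nccl_tests_size_args` is a hypothetical helper, not part of nccl-tests:

```python
BYTES_PER_ELEMENT = 4  # matches the (7347200 * 4) bytes message size quoted in this thread


def nccl_tests_size_args(num_elements: int) -> list[str]:
    """Build the -b/-e arguments so all_reduce_perf runs one fixed message size."""
    nbytes = num_elements * BYTES_PER_ELEMENT
    return ["-b", str(nbytes), "-e", str(nbytes)]


for num_elements in (1054725, 7347200):
    print(num_elements, " ".join(nccl_tests_size_args(num_elements)))
```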
from pytorch.
Something that could help us a lot since you already have envs setup would be to get profiler traces from the two envs to compare. Would you be able to do that and share the traces?
from pytorch.
@awgu If you could tell me how to do that, I'll be happy to share the traces.
from pytorch.
And also regarding NCCL perf of different versions, I did some micro benchmarks too using https://github.com/NVIDIA/nccl-tests.
There seems to be no difference across versions, with the sizes you provided.
See https://gist.github.com/kwen2501/89c28b6d12045b45ccf7e33816af2713.
But I was testing them on 4 x A100 with NVLink, instead of 4 x A10G, so this does not rule out a platform-specific issue.
That said, another possibility is that the regression comes from the torch side.
from pytorch.
Thanks!
Are you familiar with running nccl-tests? Can you run it with NCCL 2.18.5 and 2.20.5 on your platform? Thanks!
from pytorch.
Thanks, I will run the NCCL tests, but before that I have something else to share. I ran the same tests on two other machines.
8 x V100
torch 2.1
Original script at the top of this issue shows: ~1h40mins.
time per all-reduce rank=5: 0.44073
time per all-reduce rank=6: 0.44124
time per all-reduce rank=3: 0.44104
time per all-reduce rank=1: 0.44083
time per all-reduce rank=4: 0.44104
time per all-reduce rank=2: 0.44104
time per all-reduce rank=0: 0.44114
time per all-reduce rank=7: 0.44104
torch 2.3
Original script at the top of this issue shows: ~1h40mins.
time per all-reduce rank=3: 0.50012
time per all-reduce rank=6: 0.50033
time per all-reduce rank=2: 0.49992
time per all-reduce rank=5: 0.49951
time per all-reduce rank=7: 0.49930
time per all-reduce rank=1: 0.50022
time per all-reduce rank=4: 0.49295
time per all-reduce rank=0: 0.49132
8 x A100
torch2.1
Original script at the top of this issue shows: ~50mins.
time per all-reduce rank=1: 0.38461
time per all-reduce rank=0: 0.38410
time per all-reduce rank=3: 0.38472
time per all-reduce rank=2: 0.38420
time per all-reduce rank=4: 0.38513
time per all-reduce rank=7: 0.33761
time per all-reduce rank=6: 0.38533
time per all-reduce rank=5: 0.32881
torch2.3
Original script at the top of this issue shows: ~50mins.
time per all-reduce rank=4: 0.36547
time per all-reduce rank=2: 0.36598
time per all-reduce rank=5: 0.36577
time per all-reduce rank=1: 0.36588
time per all-reduce rank=7: 0.36649
time per all-reduce rank=0: 0.36618
time per all-reduce rank=6: 0.32758
time per all-reduce rank=3: 0.33905
All this suggests that there may be some issue specific to the 4 x A10G setup. However, the funny part is that I had first noticed slow training on my 8 x A100 machine for a larger codebase that relies on transformers. I still observe slow training in my original code. I thought I had isolated the problem in the minimal example in this issue, which led me to the all-reduce timings, but I am not exactly sure what's going on now.
from pytorch.
I also re-ran the test in this issue on 8 x A100.
- torch==2.1.2 transformers==4.40.2 accelerate==0.30.1: ~41hrs.
- torch==2.3.0 transformers==4.40.2 accelerate==0.30.1: ~41hrs.
from pytorch.
Sorry, I may have messed up my environments in the last couple of tests. Need to take a break now. Will post an update tomorrow.
from pytorch.