
Adding p2p comms to torch.compile

Recent PyTorch-related releases might make some of the discussion here irrelevant. Still, as long as the recently published experimental features aren't in PyTorch main, I think people will keep relying on PyTorch's compilation capabilities (vLLM, for example, does).

Overall, this post is me trying to get some ideas across and present my current understanding. Happy to receive feedback on where my understanding is off or what can be investigated further!

There is ample content here for a deep dive into the PyTorch codebase. We will not assume proficiency, only overall familiarity with PyTorch and compilation.

Big thank you to Tal Ben-Nun for mentoring and guidance for this project.

Introduction

I’ve recently opened a PR to allow PyTorch’s compilation pipeline to understand point-to-point communications.

Curious, one might say: most operators we pass through (either forwards or during autodiff) are collectives, so why would you want this specific feature? Several applications come to mind; I will offer two that are commonly used in HPC and ML (a sketch of the first follows the list).

  1. Halo exchanges

  2. RingAttention & friends.
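To make the first item concrete, here is a minimal sketch of a blocking 1-D halo exchange written against the eager point-to-point API (my own example, not code from the PR; it assumes an initialized process group, periodic boundaries, and a 2-D field whose first and last rows are the halo regions):

import torch
import torch.distributed as dist

def halo_exchange_1d(field: torch.Tensor) -> torch.Tensor:
    """Exchange one halo row with each neighbour on a periodic 1-D ring (sketch)."""
    rank, world = dist.get_rank(), dist.get_world_size()
    nxt, prv = (rank + 1) % world, (rank - 1) % world

    recv_lo = torch.empty_like(field[0])   # halo arriving from the previous rank
    recv_hi = torch.empty_like(field[-1])  # halo arriving from the next rank

    work = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, field[1].contiguous(), prv),   # first interior row -> prv's upper halo
        dist.P2POp(dist.isend, field[-2].contiguous(), nxt),  # last interior row -> nxt's lower halo
        dist.P2POp(dist.irecv, recv_lo, prv),
        dist.P2POp(dist.irecv, recv_hi, nxt),
    ])
    for w in work:
        w.wait()

    field[0].copy_(recv_lo)
    field[-1].copy_(recv_hi)
    return field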

Current state

One could argue that NCCL's SymmetricMemory feature might be sufficient. However, we want to enable optimizations that let the scheduler infrastructure and cudagraphs capture every path leading to lower walltime; I really want to exploit the entire PyTorch compilation pipeline to this end.

There’s been a considerable push for functionalization in PyTorch’s infrastructure to reduce complexity when reasoning about the different compilation stages, e.g. here or here. To that end, the collectives have been designed to be fully functional; see the Traceable functional collectives design document.
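For illustration, this is roughly what the functional flavour looks like at the Python level: the collective returns a fresh tensor and the wait is explicit, instead of mutating the input (a minimal sketch of mine, assuming an initialized default group on a recent PyTorch):

import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def functional_sum(x):
    # Out-of-place: x is left untouched and a new tensor is returned.
    y = funcol.all_reduce(x, "sum", group=dist.group.WORLD)
    # The wait is explicit and functional as well.
    return funcol.wait_tensor(y)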

However, here’s an immediate issue. Consider the following function.

def f(x):
  group_name = torch.distributed.group.WORLD.group_name  # name of a registered process group
  return torch.ops._c10d_functional.all_reduce_coalesced_([x], "sum", group_name)

As per the usual PyTorch convention, the trailing underscore indicates an in-place operation. Even though this op is technically part of the traceable functional collectives infrastructure, torch.compile (with any backend) will yield an error here. (Technically you can call torch.compile on it, but any subsequent execution of the compiled function will raise.)
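For completeness, triggering it takes nothing more than this (assuming an already-initialized NCCL process group and f as defined above):

compiled_f = torch.compile(f)
x = torch.ones(8, device="cuda")
compiled_f(x)  # raises on the first execution, not when torch.compile is called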

RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: _c10d_functional::all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!). We only support functionalizing operators whose outputs do not have alias annotations

In CS speak, in-place ops induce side effects, and side effects are forbidden in a functional paradigm. I'm also not entirely sure Inductor's scheduler tracks every collective call closely enough to reduce the number of copies; there is a small number of open issues suggesting otherwise.
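For contrast, here is a sketch of what one would write against the out-of-place counterpart of the same op (no trailing underscore, no alias annotations in its schema), which is the form functionalization is happy with:

def f_functional(x):
  group_name = torch.distributed.group.WORLD.group_name
  # Returns fresh tensors instead of mutating the inputs.
  outs = torch.ops._c10d_functional.all_reduce_coalesced([x], "sum", group_name)
  return [torch.ops._c10d_functional.wait_tensor(t) for t in outs]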

My changes

In my PR, I’ve somewhat abused the existing infrastructure for collectives to let point-to-point communications be traced by Dynamo. We construct temporary graph nodes for P2POps, which are typically used together with batch_isend_irecv to issue coalesced point-to-point comms. The whole idea of this batching is to reduce the overhead of creating new communicators (I highly recommend looking inside torch/csrc/distributed/c10d, e.g. this file, for deeper information).

We build P2POpVariables for P2POps that verbosely pass the information into the compiled graph, so that this

work = dist.batch_isend_irecv([
    dist.P2POp(dist.isend, x, 0),
    dist.P2POp(dist.irecv, y, 1),
])

some_work(x, z)

for w in work: w.wait()

can become something more akin to

tensors_to_wait = ops.batch_p2p_ops(
    ["isend", "irecv"],  # op_list
    [0, 1],              # peer_list
    [0, 0],              # tag_list
    [x, y],              # tensors
    group_name,
)

some_work_compiled(x, z)

for t in tensors_to_wait:
    dist.wait_tensor(t)

In my PR, none of the ops are in-place, and we induce extra copies to allow the operations to take place; see torch/distributed/_functional_collectives.py (compare the schema to the alias annotations the earlier error complained about):

"batch_p2p_ops(str[] op_list, int[] peer_list, int[] tag_list, Tensor[] tensors, str group_name) -> Tensor[]"

As a WIP this is fine, but for efficient interleaving and to cut down on expensive data movement we would want zero-copy comms! I still want to change this to be fully in-place; as we saw earlier, this is one of the reasons for wanting to relax the functional requirement in the traceable functional collectives infrastructure.
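To see where the copies come from, here is a purely hypothetical illustration (in the same pseudocode style as above, not the code Inductor actually generates) of what an out-of-place receive into a user-provided buffer amounts to:

# Hypothetical sketch: a functional irecv has to allocate and copy.
def functional_recv_into(y, peer, group_name):
    fresh = torch.empty_like(y)  # extra allocation
    (handle,) = ops.batch_p2p_ops(["irecv"], [peer], [0], [fresh], group_name)
    received = dist.wait_tensor(handle)
    y.copy_(received)            # extra device-to-device copy back into the user's buffer
    return y

With a truly in-place op, the receive would land directly in y and both the allocation and the copy-back would disappear.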

Changes inducing bugs

So far so good: we can use this async comms feature for distributed convolutions and anything else built on blocking halo exchanges. But what if we want to interleave communication and computation?

Consider this code snippet:

def kernel(x0, x1, y0, y1):
  r = dist.get_rank()
  w = dist.get_world_size()
  nxt = (r + 1) % w
  prv = (r - 1) % w

  # First ring exchange: send x0 to the next rank, receive into y0 from the
  # previous one, and overlap the transfer with local compute on x0.
  work = dist.batch_isend_irecv([
      dist.P2POp(dist.isend, x0, nxt),
      dist.P2POp(dist.irecv, y0, prv),
  ])
  t0 = x0 * 2 + 1
  for ww in work: ww.wait()
  a = y0 + t0  # depends on the received y0

  # Second ring exchange: note that the send buffer `a` depends on the first receive.
  work = dist.batch_isend_irecv([
      dist.P2POp(dist.isend, a, nxt),
      dist.P2POp(dist.irecv, y1, prv),
  ])
  t1 = a * 1.000244140625
  for ww in work: ww.wait()
  out = y1 + t1  # depends on the received y1
  return out

It's unfortunately a little involved, but it shows the problem: the current implementation does not handle data-flow dependencies correctly. Comparing eager and compiled runs of this function yields garbage differences, indicating something is off.

Let’s investigate:

TORCH_COMPILE_DEBUG=1 \
  TORCH_LOGS=output_code \
  TORCHINDUCTOR_CACHE_DIR=$PWD/inductor_cache/$RANK \
  TRITON_CACHE_DIR=$PWD/triton_cache/$RANK \
  torchrun --nproc-per-node=4 --standalone tester.py
# tester contains the above snippet and some numerics checks
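The tester is essentially the obvious thing: run the kernel above eagerly and compiled on the same inputs and compare. A sketch of what it might look like (the shapes are my guess, matching the 1048576 elements in the generated code below):

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
torch.manual_seed(1234 + rank)  # different data per rank, identical inputs for both runs

inputs = [torch.randn(1024, 1024, device="cuda") for _ in range(4)]

# The kernel receives into y0/y1 in place, so give each run its own copies.
eager_out = kernel(*[t.clone() for t in inputs])
compiled_out = torch.compile(kernel)(*[t.clone() for t in inputs])

print(f"rank {rank}: match =", torch.allclose(eager_out, compiled_out, rtol=1e-4, atol=1e-4))
dist.destroy_process_group()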

Inspecting the Inductor cache (explicitly set via the flags above; here inductor_cache/4y/c4yjdzyq3kj4da7e2xi5gnbvzd4laxipv66ohspyo4nuecce37pv.py):

buf0 = _c10d_functional.batch_p2p_ops(['isend','irecv'], [2,0], [0,0], [arg0_1, arg1_1], '0')
buf1 = buf0[0]
buf2 = buf0[1]
wait_tensor(buf1); wait_tensor(buf2)
triton_poi_fused_add_mul_0.run(arg1_1, arg0_1, arg2_1, buf7, buf15, 1048576, stream=stream1)

and only after the kernel call, it reads

buf8 = _c10d_functional.batch_p2p_ops(['isend','irecv'], [2,0], [0,0], [buf7, arg2_1], '0')
buf9 = buf8[0]; buf10 = buf8[1]
wait_tensor(buf9); wait_tensor(buf10)

Verdict: Inductor hoisted the calls and wired them to the wrong arguments; it does not carry the dependencies correctly. Surely, we could try to track down the wrongly-assigned dependencies here. Or: would making TFC allow side effects be the smarter choice, letting Inductor handle this itself?

We would get zero-copy all_reduce, and my newly introduced point-to-point comms would compile much more easily.

All this to say: my current additions need some work. Dependency management is broken, and I need to fix it to enable actual pipelining.