Question about send/recv scheduling #784

Open
libertyeagle opened this issue Feb 10, 2023 · 8 comments

Comments

@libertyeagle

libertyeagle commented Feb 10, 2023

Hi, I have a question about how P2P send/recv tasks are scheduled into kernel plans.
It seems that in scheduleP2pTasksToPlan NCCL schedules send/recv tasks in a group according to a sendOrder and recvOrder that all peers have consensus on, i.e., at the i-th loop, if rank r2's recvOrder[i] == r1, then we must have sendOrder[i] == r2 on rank r1.

Hence, for a specific i, we may have the following send/recv pattern (i=2 for intra-node):
1->3->5->1

rank 1: sendPeer=3, recvPeer=5
rank 3: sendPeer=5, recvPeer=1
rank 5: sendPeer=1, recvPeer=3

In this case, I wonder how the send/recv ncclWorkElemP2p are scheduled in a way that prevents deadlock. It seems that when addP2pToPlan is called, we may schedule a recv ncclWorkElemP2p to one ncclWork and the corresponding send ncclWorkElemP2p to a following ncclWork, so send and recv will not execute in parallel. This could happen, e.g., if we only have 1 channel for P2P communication and that channel's workQueue tail ncclWork already has 8 p2pSend work elements.
We could then have rank 1 waiting to recv from rank 5, rank 3 waiting on rank 1, and rank 5 waiting on rank 3. The three ranks' send tasks reside in a later ncclWork, so they would be blocked until the recvs finish, yet this does not seem to be the case in practice.

Thanks!

@sjeaugey
Member

Deadlock avoidance in send/recv is solely based on the principle of rotating pairs: when you send to rank + X (or node + X), you also receive at the same time from rank - X (or node - X).

Coupling each send to its symmetric receive splits any alltoall[v] operation into a series of rings which can't hang. In practice, we don't execute one ring at a time; we have many rings happening in parallel (p2p nChannels x 8, up to 256) so for relatively small sizes, everything is effectively happening in parallel.
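
As an aside, here is a minimal sketch of the rotating-pair principle (illustration only, not NCCL's actual scheduleP2pTasksToPlan code; the peer formulas below just restate the delta-based pairing described above): at each step delta, every rank sends to rank + delta and receives from rank - delta, so each step decomposes into closed rings. With nranks = 6 and delta = 2 this reproduces the 1->3->5->1 ring from the question.

  /* Rotating send/recv pairs, illustration only. All ranks agree on the same
   * sequence of delta steps, so every recv in a step is matched by a send on
   * another rank in that same step, forming closed rings that cannot hang. */
  #include <stdio.h>

  int main(void) {
    int nranks = 6;
    for (int delta = 1; delta < nranks; delta++) {
      printf("step delta=%d:\n", delta);
      for (int rank = 0; rank < nranks; rank++) {
        int sendPeer = (rank + delta) % nranks;
        int recvPeer = (rank - delta + nranks) % nranks;
        printf("  rank %d: send -> %d, recv <- %d\n", rank, sendPeer, recvPeer);
      }
    }
    return 0;
  }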

@libertyeagle
Author

libertyeagle commented Feb 10, 2023

Thanks a lot! Yes, I understand send/recv will be scheduled onto a ring according to sendOrder and recvOrder. Parallel send/recv requires these sends and recvs to be scheduled into the same ncclWork. But it seems that, under an extreme condition, a paired recv and send on the ring can be scheduled into different ncclWork?

https://github.com/NVIDIA/nccl/blob/master/src/enqueue.cc#L659

As the recv and send tasks are scheduled into ncclWork via two separate calls to addP2pToPlan, could the recv be scheduled into the tail ncclWork of a channel's workQueue, while a new ncclWork is allocated when the send is scheduled?

@sjeaugey
Member

a paired recv and send on the ring can be scheduled into different ncclWork

That should not be the case, although there is always a potential for bugs. I'd need to spend quite some time getting back into this algorithm to see whether such a case can actually occur. Maybe @jbachan has a fresher view on this.

@libertyeagle
Author

I think it might be a bug, since when addP2pToPlan is called to schedule a paired send/recv, it only checks whether the current ncclWork has enough p2p workElem slots for that type (send or recv). In an ncclWork, we have 8 slots for sends (odd slots) and 8 slots for recvs (even slots).
In the 0..(i-1)-th loops here (https://github.com/NVIDIA/nccl/blob/master/src/enqueue.cc#L622), a rank could schedule sends to 8 different peers on a single channel, taking up all 8 send slots in the tail ncclWork of chan->workQueue.
Then in the i-th loop, suppose this rank needs to both send and recv. It first schedules the recv; since none of the recv slots have been used, it reuses the tail ncclWork. However, when the paired send of the i-th loop is scheduled, all 8 send slots are already used, so a new ncclWork is allocated for the send.
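
To make the hypothesized split concrete, here is a small stand-alone model of the slot accounting (an illustration under assumptions, not NCCL's actual addP2pToPlan; only the 8 send + 8 recv slots per ncclWork are taken from the discussion above, everything else is a simplification):

  /* Hypothetical model of the per-ncclWork slot accounting, illustration only.
   * Assumption: the tail ncclWork has 8 recv slots and 8 send slots; when the
   * needed type is full, a fresh ncclWork becomes the tail. */
  #include <stdio.h>

  #define SLOTS_PER_TYPE 8

  struct WorkModel { int nextSend, nextRecv, id; };
  static int workCounter = 0;

  /* Returns the id of the (modeled) ncclWork the element lands in. */
  static int enqueue(struct WorkModel *w, int isSend) {
    int *next = isSend ? &w->nextSend : &w->nextRecv;
    if (*next == SLOTS_PER_TYPE) {      /* tail work is full for this type   */
      w->nextSend = w->nextRecv = 0;    /* a fresh ncclWork becomes the tail */
      w->id = ++workCounter;
    }
    (*next)++;
    return w->id;
  }

  int main(void) {
    struct WorkModel w = {0, 0, 0};
    for (int i = 0; i < 8; i++) enqueue(&w, 1);  /* 8 earlier unpaired sends */
    int recvWork = enqueue(&w, 0);               /* paired recv: still fits  */
    int sendWork = enqueue(&w, 1);               /* paired send: spills over */
    printf("recv in work %d, send in work %d\n", recvWork, sendWork);
    return 0;
  }

Under these assumptions it prints "recv in work 0, send in work 1", i.e., the paired recv and send end up in different ncclWork, which is the scenario described above.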

@libertyeagle
Author

libertyeagle commented Feb 10, 2023

Let's assume we modify the boundary check here (https://github.com/NVIDIA/nccl/blob/master/src/enqueue.cc#L209) from:
if (chan->p2pTailElem[elem->p2pType-1] < NCCL_MAX_WORK_ELEMENTS_P2P)
to
if (chan->p2pTailElem[elem->p2pType-1] < 3)

Then in an ncclWork we have 2 recv slots but only 1 send slot. This does not affect the P2P task enqueue logic; it just makes the bug easier to demonstrate. Otherwise, I think we would need at least 9 ranks to reproduce the bug described above.

In this case, if we modify the sendrecv test (https://github.com/NVIDIA/nccl-tests/blob/master/src/sendrecv.cu#L46) so that each rank sends to and recvs from itself before the send/recv with its peer:

  NCCLCHECK(ncclGroupStart());
  NCCLCHECK(ncclSend(sendbuff, count, type, rank, comm, stream));
  NCCLCHECK(ncclRecv(recvbuff, count, type, rank, comm, stream));
  NCCLCHECK(ncclSend(sendbuff, count, type, sendPeer, comm, stream));
  NCCLCHECK(ncclRecv(recvbuff, count, type, recvPeer, comm, stream));
  NCCLCHECK(ncclGroupEnd());

Then running with the following command on 2 GPUs will hang:
NCCL_MAX_NCHANNELS=1 ./build/sendrecv_perf -b 8 -e 128M -f 2 -g 1 -t 2
In this case, rank 0's send to rank 1 will be scheduled into a later ncclWork than rank 0's recv from rank 1.

@sjeaugey
Member

Let's assume we modify the boundary check here [...] to [...] 3

Yes, that's why we have 8 sends and 8 receives per p2pWorkElem, and we can't replace a send by a receive or vice versa (receives always use even slots and sends use odd slots IIRC). So we always have a slot for a pair of send/receive.

@sjeaugey
Member

I think it might be a bug, since when addP2pToPlan is called to schedule a paired send/recv, it only checks whether the current ncclWork has enough p2p workElem slots for that type (send or recv).

That's possible indeed when we have sends without a matching receive followed by send/receive pairs. I'll let @jbachan review that possibility and opine.

@jbachan
Collaborator

jbachan commented Feb 22, 2023

I haven't been able to convince myself that we are bug free, so I'll just blather about how I wish it worked:

Deadlock freedom can be achieved even if we processed at most one send OR recv at a time. To envision how, consider that each matching pair of (send(), recv()) is a global collective involving all ranks, except that all but 2 of them (1 sender + 1 receiver) have nothing to do, like an extremely sparse alltoallv. Let's introduce a new collective call with these semantics: ncclSendRecv(int sender, int receiver). Everybody must call ncclSendRecv() even if they are neither of the two ranks participating. Now that everything is a collective call, it's easy to make it run deadlock free: just have all ranks issue the ncclSendRecv()'s in the same order. Here's the simplest possible order, which is terrible for perf, but ignore that:

for (int sender=0; sender<nranks; sender++) {
  for (int recver=0; recver<nranks; recver++) {
    ncclSendRecv(sender, recver);
  }
}

Since the implementation of ncclSendRecv just enqueues send and recv tasks to the device when the calling rank is actually participating in this send/recv, it's silly to require all ranks to issue all ncclSendRecv's, so we relax that: only ranks which are either the sender or the receiver need to issue the call. Since not all ranks issue every ncclSendRecv, we can no longer demand that they all issue the same ncclSendRecv's in the same order. Instead, the new constraint is that there exists some global order of ncclSendRecv's such that the order submitted by any chosen rank never violates it. We can now simplify the code to:

for (int sender=0; sender<nranks; sender++) {
  for (int recver=0; recver<nranks; recver++) {
    if (myrank == sender || myrank == recver) ncclSendRecv(sender, recver);
  }
}

With that we can just define ncclSend/Recv in terms of ncclSendRecv and resimplify like so:

inline void ncclSend(int recver) { ncclSendRecv(myrank, recver); }
inline void ncclRecv(int sender) { ncclSendRecv(sender, myrank); }
for (int sender=0; sender<nranks; sender++) {
  for (int recver=0; recver<nranks; recver++) {
    if (myrank == sender) ncclSend(recver);
    if (myrank == recver) ncclRecv(sender);
  }
}

And this hints at how I think NCCL should work: it should have one global SendRecv order against which all local ncclSend's and ncclRecv's are ordered, such that even if the GPU were to process this list serially we would still be deadlock free. Adding parallelism on top of that is just for performance. We could store this order locally as an array of struct { uint32_t peer:31, sendNotRecv:1; };.

Unfortunately what we actually have are two separate orders int sendOrder[] and int recvOrder[] which know nothing of each other.
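
For what it's worth, a minimal sketch of building such a single local order (an illustration under the assumptions above, not existing NCCL code; only the struct layout comes from the suggestion, and the naive O(nranks^2) loop is just the example order from this comment):

  /* Sketch: one local list of p2p ops consistent with a single global
   * SendRecv order. Processing this list serially on every rank would be
   * deadlock free; parallelism can then be layered on top for performance. */
  #include <stdint.h>
  #include <stdlib.h>

  struct P2pOp { uint32_t peer:31, sendNotRecv:1; };

  static struct P2pOp* buildLocalOrder(int myrank, int nranks, int *count) {
    /* At most nranks sends and nranks recvs involve this rank. */
    struct P2pOp *ops = malloc(sizeof(*ops) * 2 * (size_t)nranks);
    int n = 0;
    for (int sender = 0; sender < nranks; sender++) {
      for (int recver = 0; recver < nranks; recver++) {
        if (myrank == sender)
          ops[n++] = (struct P2pOp){ .peer = (uint32_t)recver, .sendNotRecv = 1 };
        if (myrank == recver)
          ops[n++] = (struct P2pOp){ .peer = (uint32_t)sender, .sendNotRecv = 0 };
      }
    }
    *count = n;
    return ops;
  }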
