Skip to content

Commit

Permalink
2.18.1-1
Browse files Browse the repository at this point in the history
Add support for IB SHARP to NVLS (NVLink SHARP algorithm).
Add NVLS+Tree algorithm.
Add support for memory management using cuMem* functions.
Use all NICs for Send/Receive operations on systems with more than
one NIC per GPU (#804).
Add ncclCommSplit primitive, with resource sharing option in config.
Fix alltoallv hang (#788).
Increase number of channels on H100 when we're not limited by NVLink.
Improve error reporting in case of IB failure, printing local and
remote ID (#779).
Add build option to allow compilation against RDMA includes instead
of dynamically loading IB verbs symbols (#802).
Fix context creation for progress thread (#803).
NET/IB: add option to use multiple QPs in round-robin mode.
Fix tree performance issue when NVB is disabled on HCM topologies.
  • Loading branch information
sjeaugey committed Apr 18, 2023
1 parent 9b7d5ed commit d97a32f
Show file tree
Hide file tree
Showing 64 changed files with 4,752 additions and 3,125 deletions.
5 changes: 5 additions & 0 deletions makefiles/common.mk
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ DEBUG ?= 0
TRACE ?= 0
PROFAPI ?= 1
NVTX ?= 1
RDMA_CORE ?= 0

NVCC = $(CUDA_HOME)/bin/nvcc

Expand Down Expand Up @@ -106,3 +107,7 @@ endif
ifneq ($(PROFAPI), 0)
CXXFLAGS += -DPROFAPI
endif

ifneq ($(RDMA_CORE), 0)
CXXFLAGS += -DNCCL_BUILD_RDMA_CORE=1
endif
2 changes: 1 addition & 1 deletion makefiles/version.mk
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
##### version
NCCL_MAJOR := 2
NCCL_MINOR := 17
NCCL_MINOR := 18
NCCL_PATCH := 1
NCCL_SUFFIX :=
PKG_REVISION := 1
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ include ../makefiles/version.mk
##### src files
INCEXPORTS := nccl.h nccl_net.h
LIBSRCFILES := init.cc init_nvtx.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \
misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvwrap.cc misc/gdrwrap.cc \
misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvsymbols.cc misc/ibvwrap.cc misc/gdrwrap.cc \
misc/utils.cc misc/argcheck.cc misc/socket.cc misc/shmutils.cc misc/profiler.cc misc/param.cc misc/strongstream.cc \
misc/ipcsocket.cc \
transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc transport/nvls.cc \
Expand Down
72 changes: 70 additions & 2 deletions src/bootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,74 @@ ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm*
return ncclSuccess;
}

// Bootstrap a communicator created via ncclCommSplit. Unlike bootstrapInit,
// there is no root rendezvous: the child comm reuses the PARENT's bootstrap
// network to exchange listen addresses, then builds its own allgather ring
// and (unless resources are shared) its own proxy service.
//   handle      - carries the magic value shared by all ranks of the new comm
//   comm        - child communicator being initialized (rank/nRanks already set)
//   parent      - parent communicator whose bootstrap carries the address exchange
//   color, key  - split arguments; used here only in the final INFO log line
//   parentRanks - maps each child rank to its rank in the parent communicator
// Returns ncclSuccess, or the first error from any NCCLCHECKGOTO'd call.
ncclResult_t bootstrapSplit(struct ncclBootstrapHandle* handle, struct ncclComm* comm, struct ncclComm* parent, int color, int key, int* parentRanks) {
  ncclResult_t ret = ncclSuccess;
  int rank = comm->rank;
  int nranks = comm->nRanks;
  int prev, next;
  ncclSocketAddress listenAddr, tmpAddr;
  struct ncclSocket* proxySocket;
  struct bootstrapState* state;

  NCCLCHECKGOTO(ncclCalloc(&state, 1), ret, fail);
  state->rank = rank;
  state->nranks = nranks;
  state->abortFlag = comm->abortFlag;
  comm->bootstrap = state;
  comm->magic = state->magic = handle->magic;

  // Ring neighbors of this rank, translated to PARENT ranks so we can reach
  // them over the parent's bootstrap network.
  prev = parentRanks[(rank-1+nranks)%nranks];
  next = parentRanks[(rank+1)%nranks];

  // Setup my sockets for the allgather ring and other p2p connections
  NCCLCHECKGOTO(ncclSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag, 0), ret, fail);
  NCCLCHECKGOTO(ncclSocketInit(&state->ringRecvSocket, NULL, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag, 0), ret, fail);

  // Create socket for other ranks to contact me
  NCCLCHECKGOTO(ncclSocketListen(&state->listenSock), ret, fail);

  // Exchange ring addresses over the parent's bootstrap: send my listen
  // address to my predecessor, receive my successor's listen address.
  // NOTE(review): tag -2 appears reserved for this split exchange — confirm
  // against other bootstrapSend/bootstrapRecv tag users.
  NCCLCHECKGOTO(ncclSocketGetAddr(&state->listenSock, &listenAddr), ret, fail);
  NCCLCHECKGOTO(bootstrapSend(parent->bootstrap, prev, -2, &listenAddr, sizeof(union ncclSocketAddress)), ret, fail);
  NCCLCHECKGOTO(bootstrapRecv(parent->bootstrap, next, -2, &tmpAddr, sizeof(union ncclSocketAddress)), ret, fail);

  // Connect the send side of the ring to my successor's listen address.
  NCCLCHECKGOTO(ncclSocketInit(&state->ringSendSocket, &tmpAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag, 0), ret, fail);
  NCCLCHECKGOTO(ncclSocketConnect(&state->ringSendSocket), ret, fail);
  // Accept the connect request from the previous rank in the AllGather ring
  NCCLCHECKGOTO(ncclSocketAccept(&state->ringRecvSocket, &state->listenSock), ret, fail);

  // AllGather all listen handlers
  NCCLCHECKGOTO(ncclCalloc(&state->peerCommAddresses, nranks), ret, fail);
  memcpy(state->peerCommAddresses+rank, &listenAddr, sizeof(union ncclSocketAddress));
  NCCLCHECKGOTO(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union ncclSocketAddress)), ret, fail);

  if (parent->config.splitShare) {
    /* map local rank to top parent local rank. */
    for (int i = 0; i < nranks; ++i) {
      comm->topParentRanks[i] = parent->topParentRanks[parentRanks[i]];
    }
    // Reuse the parent's proxy service instead of spawning our own; the
    // refcount keeps it alive until every sharing comm is destroyed.
    comm->proxyState = parent->sharedRes->proxyState;
    ncclAtomicRefCountIncrement(&parent->sharedRes->proxyState->refCount);
  } else {
    // Create the service proxy
    NCCLCHECKGOTO(ncclCalloc(&state->peerProxyAddresses, nranks), ret, fail);
    NCCLCHECKGOTO(ncclCalloc(&proxySocket, 1), ret, fail);
    NCCLCHECKGOTO(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag, 0), ret, fail);
    NCCLCHECKGOTO(ncclSocketListen(proxySocket), ret, fail);
    NCCLCHECKGOTO(ncclSocketGetAddr(proxySocket, &tmpAddr), ret, fail);
    memcpy(state->peerProxyAddresses + rank, &tmpAddr, sizeof(union ncclSocketAddress));
    // Share proxy listen addresses over the freshly-built ring, then start
    // the proxy for this comm.
    NCCLCHECKGOTO(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress)), ret, fail);
    NCCLCHECKGOTO(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses), ret, fail);
  }

  INFO(NCCL_INIT, "bootstrapSplit: rank %d nranks %d color %d key %d prev %d next %d - DONE", rank, nranks, color, key, prev, next);

exit:
  return ret;
fail:
  // All NCCLCHECKGOTO failures land here; no extra cleanup is performed
  // before returning the error code.
  goto exit;
}

ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
struct bootstrapState* state = (struct bootstrapState*)commState;
char* data = (char*)allData;
Expand Down Expand Up @@ -336,7 +404,7 @@ ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int s
struct bootstrapState* state = (struct bootstrapState*)commState;
struct ncclSocket sock;

NCCLCHECKGOTO(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->magic, ncclSocketTypeBootstrap, state->abortFlag), ret, fail);
NCCLCHECKGOTO(ncclSocketInit(&sock, state->peerCommAddresses+peer, state->magic, ncclSocketTypeBootstrap), ret, fail);
NCCLCHECKGOTO(ncclSocketConnect(&sock), ret, fail);
NCCLCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail);
NCCLCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail);
Expand Down Expand Up @@ -397,7 +465,7 @@ ncclResult_t bootstrapIntraNodeBroadcast(void* commState, int *ranks, int rank,
}
}
else {
NCCLCHECK(bootstrapRecv(commState, ranks[root], /*tag=*/rank, bcastData, size));
NCCLCHECK(bootstrapRecv(commState, ranks[root], /*tag=*/ranks[rank], bcastData, size));
}

TRACE(NCCL_INIT, "rank %d nranks %d root %d size %d - DONE", rank, nranks, root, size);
Expand Down
141 changes: 118 additions & 23 deletions src/channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,144 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) {
channel->id = channelId;
channel->workFifoSent = 0;

NCCLCHECK(ncclStrongStreamAcquireUncaptured(&comm->deviceStream));
struct ncclSharedResources* sharedRes = comm->sharedRes;

// The extra on nRanks+1 is for collnet root (i.e. network)
channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer>(&comm->memPermanent, nPeers);
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nPeers, comm->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devPeers);
NCCLCHECK(ncclStrongStreamAcquireUncaptured(&sharedRes->deviceStream));

if (channel->peers == NULL) {
// The extra on nRanks+1 is for collnet root (i.e. network)
// Allocate everything related to sharedRes with ncclCalloc as this can be
// shared between communicators hence should not be tied to comm.
if (sharedRes->peers[channelId] == NULL) {
NCCLCHECK(ncclCalloc(sharedRes->peers + channelId, sharedRes->tpNRanks));
}
channel->peers = ncclMemoryStackAlloc<struct ncclChannelPeer*>(&comm->memPermanent, nPeers);
for (int r = 0; r < nRanks; r++) {
channel->peers[r] = comm->sharedRes->peers[channelId] + comm->topParentRanks[r];
ncclAtomicRefCountIncrement(&channel->peers[r]->refCount);
}
}

if (channel->devPeers == NULL) {
if (sharedRes->devPeers[channelId] == NULL) {
NCCLCHECK(ncclCudaCallocAsync(sharedRes->devPeers + channelId, sharedRes->tpNRanks, sharedRes->deviceStream.cudaStream));
}
/* channel->devPeers is not shared, so just free it when calling commFree() */
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nPeers, sharedRes->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devPeers);
for (int r = 0; r < nRanks; r++) {
uintptr_t addr = (uintptr_t)(comm->sharedRes->devPeers[channelId] + comm->topParentRanks[r]);
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
}
}

channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks);
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, comm->deviceStream.cudaStream));
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, sharedRes->deviceStream.cudaStream));
ncclCommPushCudaFree(comm, channel->devRingUserRanks);

NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream));
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));

return ncclSuccess;
}

ncclResult_t initNvlsChannel(struct ncclComm* comm, int channelId, struct ncclComm* parent, bool share) {
struct ncclChannel* channel = &comm->channels[channelId];
struct ncclSharedResources* sharedRes = comm->sharedRes;

if (channel->nvlsPeers != NULL)
return ncclSuccess;

if (channel->id == -1)
NCCLCHECK(initChannel(comm, channelId));

NCCLCHECK(ncclStrongStreamAcquireUncaptured(&sharedRes->deviceStream));

for (int r=0; r < nPeers; ++r) {
for (int b=0; b < NCCL_MAX_CONNS; b++) {
channel->peers[r].send[b].comm = comm;
channel->peers[r].recv[b].comm = comm;
if (share) {
channel->nvlsPeers = parent->channels[channelId].nvlsPeers;
channel->nvlsDevPeers = parent->channels[channelId].nvlsDevPeers;
for (int r = 0; r < comm->localRanks; ++r) {
int tr = comm->topParentLocalRanks[r];
uintptr_t addr = (uintptr_t)(parent->channels[channelId].nvlsDevPeers + tr);
channel->peers[comm->nRanks + 1 + r] = parent->channels[channelId].nvlsPeers + tr;
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
ncclAtomicRefCountIncrement(&parent->channels[channelId].nvlsPeers[tr].refCount);
}
} else {
NCCLCHECK(ncclCalloc(&channel->nvlsPeers, comm->localRanks));
NCCLCHECK(ncclCudaCallocAsync(&channel->nvlsDevPeers, comm->localRanks, sharedRes->deviceStream.cudaStream));
for (int r = 0; r < comm->localRanks; ++r) {
uintptr_t addr = (uintptr_t)(channel->nvlsDevPeers + r);
channel->peers[comm->nRanks + 1 + r] = channel->nvlsPeers + r;
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
ncclAtomicRefCountIncrement(&channel->nvlsPeers[r].refCount);
}
}

NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));

return ncclSuccess;
}

// Set up the CollNet peer slot for one channel. The CollNet peer lives at
// index comm->nRanks of channel->peers/devPeers. With share=true the
// host/device peer structs are borrowed from the parent communicator
// (ncclCommSplit resource sharing) and refcounted; otherwise fresh structs
// are allocated for this comm. Idempotent: returns early if already done.
//   comm      - communicator owning the channel
//   channelId - index into comm->channels
//   parent    - parent communicator; only dereferenced when share is true
//   share     - whether to reuse the parent's CollNet peer structures
ncclResult_t initCollnetChannel(struct ncclComm* comm, int channelId, struct ncclComm* parent, bool share) {
  struct ncclChannel* channel = &comm->channels[channelId];
  struct ncclSharedResources* sharedRes = comm->sharedRes;
  uintptr_t addr;

  // Already initialized for this channel.
  if (channel->collnetPeers != NULL)
    return ncclSuccess;

  // Lazily initialize the base channel if it has not been set up yet.
  if (channel->id == -1)
    NCCLCHECK(initChannel(comm, channelId));

  // Async device work below is enqueued on the shared device stream.
  NCCLCHECK(ncclStrongStreamAcquireUncaptured(&sharedRes->deviceStream));

  if (share) {
    // Borrow the parent's CollNet peer structures and bump their refcount
    // so freeing is deferred until the last sharer releases them.
    channel->collnetPeers = parent->channels[channelId].collnetPeers;
    channel->collnetDevPeers = parent->channels[channelId].collnetDevPeers;
    addr = (uintptr_t)parent->channels[channelId].collnetDevPeers;
    channel->peers[comm->nRanks] = parent->channels[channelId].collnetPeers;
    // Publish the device-side peer pointer into this comm's devPeers array.
    NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
    ncclAtomicRefCountIncrement(&parent->channels[channelId].collnetPeers->refCount);
  } else {
    // Allocate one host peer struct and its device mirror for this comm.
    NCCLCHECK(ncclCalloc(&channel->collnetPeers, 1));
    NCCLCHECK(ncclCudaCallocAsync(&channel->collnetDevPeers, 1, sharedRes->deviceStream.cudaStream));
    addr = (uintptr_t)channel->collnetDevPeers;
    channel->peers[comm->nRanks] = channel->collnetPeers;
    // Publish the device-side peer pointer into this comm's devPeers array.
    NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
    ncclAtomicRefCountIncrement(&channel->collnetPeers->refCount);
  }

  NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));

  return ncclSuccess;
}

ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks) {
ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks, int collnetNRanks, int nvlsNRanks) {
int nPeers = nRanks + collnetNRanks + nvlsNRanks;
/* channel peers are only valid when async init thread completes commAlloc() and
* the channel is intialized with initChannel(); if either is not done, this channel
* should never be free. */
if (channel->id == -1 || channel->peers == NULL) return ncclSuccess;

// Free transport proxy resources
// Note: free all send resources first due to CollNet arrangement
for (int r=0; r<nRanks+1; r++) {
struct ncclChannelPeer* peer = channel->peers+r;
for (int b=0; b<NCCL_MAX_CONNS; b++) {
if (peer->send[b].transportComm) NCCLCHECK(peer->send[b].transportComm->free(peer->send+b));
}
}
for (int r=0; r<nRanks+1; r++) {
struct ncclChannelPeer* peer = channel->peers+r;
for (int b=0; b<NCCL_MAX_CONNS; b++) {
if (peer->recv[b].transportComm) NCCLCHECK(peer->recv[b].transportComm->free(peer->recv+b));
for (int r = 0; r < nPeers; r++) {
struct ncclChannelPeer* peer = channel->peers[r];
if (peer) {
if (ncclAtomicRefCountDecrement(&peer->refCount) == 0) {
for (int b=0; b<NCCL_MAX_CONNS; b++) {
if (peer->send[b].transportComm) NCCLCHECK(peer->send[b].transportComm->free(peer->send+b));
if (peer->recv[b].transportComm) NCCLCHECK(peer->recv[b].transportComm->free(peer->recv+b));
}
if (r == nRanks) {
free(channel->collnetPeers);
ncclCudaFree(channel->collnetDevPeers);
} else if (r == nPeers - 1) {
free(channel->nvlsPeers);
ncclCudaFree(channel->nvlsDevPeers);
}
}
}
}

return ncclSuccess;
}
14 changes: 7 additions & 7 deletions src/collectives/device/all_gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ namespace {
if (inputBuf + chunkOffset == outputBuf + offset) { // In place
prims.directSend(chunkOffset, offset, nelem);
} else {
prims.directCopySend(chunkOffset, offset, offset, nelem);
prims.directCopySend(chunkOffset, offset, nelem);
}

// k-2 steps: copy to next GPU
for (int j=1; j<nranks-1; ++j) {
rankDest = ringRanks[nranks-j];
offset = chunkOffset + rankDest * size;

prims.directRecvCopySend(offset, offset, nelem);
prims.directRecvCopySend(offset, nelem);
}

// Make final copy from buffer to dest.
Expand Down Expand Up @@ -118,19 +118,19 @@ struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SI

if (tid < tidEndGather) {
// Gather
int group = (0*Proto::MaxGroupWidth) | (0<<16);
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, args->redOpArg, group, args);
prims(tid, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff,
args->redOpArg, 0*Proto::MaxGroupWidth, 0, 0);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset);
prims.gather(offset, nvls->nHeads*size, nelem, size, -1, 0);
}
} else if (tid < tidEndBcast) {
int group = (3*Proto::MaxGroupWidth) | (1<<16);
// Bcast through MC
// Bcast through NVLS
Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto, 0>
prims(tid-tidEndGather, nThreadsBcast, NULL, &nvls->down, args->sendbuff, NULL, args->redOpArg, group, args);
prims(tid-tidEndGather, nThreadsBcast, NULL, &nvls->down, args->sendbuff, NULL,
args->redOpArg, 3*Proto::MaxGroupWidth, 1, 1);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
ssize_t offset = gridOffset + bid*chunkSize;
int nelem = min(chunkSize, size-offset);
Expand Down
Loading

0 comments on commit d97a32f

Please sign in to comment.