Skip to content

Commit

Permalink
fix_mnnvl_with_uuid (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
chuangz0 authored Aug 16, 2024
1 parent 604d7a8 commit bc13899
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions cpp/src/wholememory/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <memory>
#include <raft/core/error.hpp>
#include <string>
#include <wholememory/tensor_description.h>
#include <wholememory/wholememory.h>

Expand Down Expand Up @@ -557,7 +558,7 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
wm_comm->clique_info.clique_rank = -1;
wm_comm->clique_info.clique_rank_num = 0;

std::set<int> clique_ids{};
std::set<std::string> clique_uuids{};

for (int r = 0; r < wm_comm->world_size; r++) {
WHOLEMEMORY_CHECK(r == p_rank_info.get()[r].rank);
Expand All @@ -583,16 +584,21 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
if (wm_comm->clique_info.clique_rank_num == 0) { wm_comm->clique_info.clique_first_rank = r; }
wm_comm->clique_info.clique_rank_num++;
}
clique_ids.insert(p_rank_info.get()[r].fabric_info.cliqueId);
clique_uuids.insert(
std::string(reinterpret_cast<const char*>(p_rank_info.get()[r].fabric_info.clusterUuid),
NVML_GPU_FABRIC_UUID_LEN));

#endif
}

#if CUDA_VERSION >= 12030
wm_comm->clique_info.clique_num = clique_ids.size();
int id = 0;
for (auto clique_id : clique_ids) {
if (clique_id == ri.fabric_info.cliqueId) { wm_comm->clique_info.clique_id = id; }
wm_comm->clique_info.clique_num = clique_uuids.size();

std::string uuid = std::string(reinterpret_cast<const char*>(ri.fabric_info.clusterUuid),
NVML_GPU_FABRIC_UUID_LEN);
int id = 0;
for (auto clique_uuid : clique_uuids) {
if (clique_uuid == uuid) { wm_comm->clique_info.clique_id = id; }
id++;
}

Expand Down

0 comments on commit bc13899

Please sign in to comment.