Skip to content

Commit

Permalink
optimize rank computation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuofan1123 committed Jul 23, 2024
1 parent 4434149 commit 35f192b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 74 deletions.
69 changes: 14 additions & 55 deletions cpp/include/wholememory/device_reference.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,20 @@ class device_reference {
__device__ __forceinline__ explicit device_reference(const wholememory_gref_t& gref)
: pointer_(static_cast<DataTypeT*>(gref.pointer)),
typed_stride_(gref.stride / sizeof(DataTypeT)),
rank_memory_offsets_(gref.rank_memory_offsets),
world_size_(gref.world_size),
same_chunk_(gref.same_chunk),
estimated_stride_(0),
cache_rank_(0),
cache_offset_(0),
cache_size_(0)
same_chunk_(gref.same_chunk)
{
assert(gref.stride % sizeof(DataTypeT) == 0);
if (typed_stride_ > 0 && !same_chunk_) {
estimated_stride_ = rank_memory_offsets_[world_size_] / world_size_;
cache_rank_ = 0;
cache_offset_ = 0;
cache_size_ = rank_memory_offsets_[1] - rank_memory_offsets_[0];
if (typed_stride_ != 0 && !same_chunk_) {
assert(world_size_ <= 8); // intra-node WHOLEMEMORY_MT_CHUNKED
for (int i = 0; i < world_size_ + 1; i++) {
assert(gref.rank_memory_offsets[i] % sizeof(DataTypeT) == 0);
typed_rank_mem_offsets_[i] = gref.rank_memory_offsets[i] / sizeof(DataTypeT);
}
}
}
__device__ device_reference() = delete;

__device__ __forceinline__ size_t copy_offsets_to_shmem(char* shmem, size_t maxsize)
{
if (typed_stride_ == 0 || same_chunk_) return 0;
size_t used_shmem_size = (world_size_ + 1) * sizeof(size_t);
if (used_shmem_size > maxsize) return 0;
size_t* shmem_offsets = reinterpret_cast<size_t*>(shmem);
for (int i = threadIdx.x; i <= world_size_; i += blockDim.x) {
shmem_offsets[i] = rank_memory_offsets_[i];
}
__syncthreads();
rank_memory_offsets_ = shmem_offsets;
size_t aligned_used_shmem_size = ((used_shmem_size - 1) / 128 + 1) * 128;
return aligned_used_shmem_size;
}

__device__ __forceinline__ DataTypeT& operator[](size_t index)
{
if (typed_stride_ == 0) { return pointer_[index]; }
Expand All @@ -68,47 +49,25 @@ class device_reference {
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - rank * typed_stride_];
} else {
size_t rank = 0;
size_t offset = index * sizeof(DataTypeT);
if (offset >= cache_offset_ && offset < cache_offset_ + cache_size_) {
rank = cache_rank_;
} else {
int estimated_rank = max(world_size_ - 1, int(offset / estimated_stride_));
if (rank_memory_offsets_[estimated_rank] > offset) {
for (int i = estimated_rank - 1; i >= 0; i--) {
if (rank_memory_offsets_[i] <= offset) {
rank = i;
break;
}
}
} else {
for (int i = estimated_rank + 1; i <= world_size_; i++) {
if (rank_memory_offsets_[i] > offset) {
rank = i - 1;
break;
}
}
size_t rank = 0;
for (int i = 1; i < world_size_ + 1; i++) {
if (index < typed_rank_mem_offsets_[i]) {
rank = i - 1;
break;
}
cache_rank_ = rank;
cache_offset_ = rank_memory_offsets_[rank];
cache_size_ = rank_memory_offsets_[rank + 1] - rank_memory_offsets_[rank];
}
return static_cast<DataTypeT**>(
static_cast<void*>(pointer_))[rank][index - cache_offset_ / sizeof(DataTypeT)];
static_cast<void*>(pointer_))[rank][index - typed_rank_mem_offsets_[rank]];
}
}

private:
DataTypeT* pointer_;
size_t* rank_memory_offsets_;
int world_size_;
size_t typed_stride_;

size_t estimated_stride_;
bool same_chunk_;
int cache_rank_;
size_t cache_offset_;
size_t cache_size_;
size_t typed_rank_mem_offsets_[8 + 1];
};

} // namespace wholememory
22 changes: 15 additions & 7 deletions cpp/src/wholememory_ops/functions/bucket_ids_func.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
namespace wholememory_ops {

template <typename IndexT>
__device__ int dest_rank(IndexT entry_idx,
size_t total_entry_count,
const size_t* embedding_entry_offsets,
int world_size)
__device__ __forceinline__ int dest_rank(IndexT entry_idx,
size_t total_entry_count,
const size_t* embedding_entry_offsets,
int world_size)
{
size_t estimated_entry_per_rank = total_entry_count / world_size;
int estimated_rank = max(world_size - 1, int(entry_idx / estimated_entry_per_rank));
Expand Down Expand Up @@ -60,13 +60,18 @@ __global__ void bucket_ids_for_ranks_kernel(const IndexT* indices,
for (int idx = threadIdx.x; idx < world_size; idx += blockDim.x) {
rank_count_shared[idx] = 0;
}
size_t* embedding_entry_offsets_shared =
reinterpret_cast<size_t*>(shmem + sizeof(size_t) * world_size);
for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) {
embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx];
}
__syncthreads();
size_t total_entry_count = embedding_entry_offsets[world_size];
size_t total_entry_count = embedding_entry_offsets_shared[world_size];
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count;
idx += blockDim.x * gridDim.x) {
IndexT node_idx = indices[idx];
if (node_idx < 0) continue;
int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets, world_size);
int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets_shared, world_size);
assert(rank >= 0 && rank < world_size);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
atomicAdd_block(&rank_count_shared[rank], 1);
Expand Down Expand Up @@ -95,7 +100,10 @@ void bucket_ids_for_ranks_temp_fn(void* indices,
block_count = std::min(block_count, sm_count * 4);
IndexT* indices_ptr = static_cast<IndexT*>(indices);
indices_ptr += indice_desc.storage_offset;
bucket_ids_for_ranks_kernel<<<block_count, BLOCK_SIZE, sizeof(int) * world_size, stream>>>(
bucket_ids_for_ranks_kernel<<<block_count,
BLOCK_SIZE,
sizeof(size_t) * (world_size * 2 + 1),
stream>>>(
indices_ptr, indice_desc.size, dev_rank_id_count_ptr, embedding_entry_offsets, world_size);
}

Expand Down
18 changes: 6 additions & 12 deletions cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,12 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
int64_t output_stride = output_desc.stride;

wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
size_t used_shm_size = embedding_dev_ref.copy_offsets_to_shmem(shm_in_char, shm_max_size);

typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
typed_data_vector<OutputT, ALIGNMENT> outputs;

int shm_size = (shm_max_size - used_shm_size) / sizeof(OutputT);
OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char + used_shm_size);
int shm_size = shm_max_size / sizeof(OutputT);
OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char);
OutputT* my_shared;
bool use_shm = true;
if (shm_size / (blockDim.x / 32) < output_desc.sizes[1]) { //
Expand Down Expand Up @@ -346,10 +345,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,

int lane_id_in_sub_warp = subwarp.thread_rank();

constexpr size_t shm_max_size = 1024 * sizeof(size_t);
__shared__ char shmem[shm_max_size];
wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
embedding_dev_ref.copy_offsets_to_shmem(shmem, shm_max_size);

int embedding_size = embedding_desc.sizes[1];
int64_t embedding_stride = embedding_desc.stride;
Expand All @@ -365,11 +361,10 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
if (embedding_table_idx < 0) continue;
int64_t embedding_offset =
embedding_desc.storage_offset + embedding_table_idx * embedding_stride;

EmbeddingT* emb_ptr = &embedding_dev_ref[embedding_offset];
for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size;
emb_idx += ALIGNMENT * SUB_WARP_SIZE) {
mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings,
&embedding_dev_ref[embedding_offset + emb_idx]);
mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings, &emb_ptr[emb_idx]);
#pragma unroll
for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
typed_data_vector_at(outputs, sub_idx) =
Expand Down Expand Up @@ -542,10 +537,9 @@ __global__ void scatter_func_kernel(const InputT* input,
int async_copy_align = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT);

wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
size_t used_shm_size = embedding_dev_ref.copy_offsets_to_shmem(shm_in_char, shm_max_size);

int shm_size = (shm_max_size - used_shm_size) / sizeof(InputT);
InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char + used_shm_size);
int shm_size = shm_max_size / sizeof(InputT);
InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char);
InputT* my_shared;
int batch_size = (shm_size / (blockDim.x / 32) - async_copy_align) /
input_stride; // indices batch size in lines
Expand Down

0 comments on commit 35f192b

Please sign in to comment.