diff --git a/cpp/include/wholememory/device_reference.cuh b/cpp/include/wholememory/device_reference.cuh
index 4ffde7d44..8f2146ae9 100644
--- a/cpp/include/wholememory/device_reference.cuh
+++ b/cpp/include/wholememory/device_reference.cuh
@@ -27,39 +27,20 @@ class device_reference {
   __device__ __forceinline__ explicit device_reference(const wholememory_gref_t& gref)
     : pointer_(static_cast<DataTypeT*>(gref.pointer)),
       typed_stride_(gref.stride / sizeof(DataTypeT)),
-      rank_memory_offsets_(gref.rank_memory_offsets),
       world_size_(gref.world_size),
-      same_chunk_(gref.same_chunk),
-      estimated_stride_(0),
-      cache_rank_(0),
-      cache_offset_(0),
-      cache_size_(0)
+      same_chunk_(gref.same_chunk)
   {
     assert(gref.stride % sizeof(DataTypeT) == 0);
-    if (typed_stride_ > 0 && !same_chunk_) {
-      estimated_stride_ = rank_memory_offsets_[world_size_] / world_size_;
-      cache_rank_       = 0;
-      cache_offset_     = 0;
-      cache_size_       = rank_memory_offsets_[1] - rank_memory_offsets_[0];
+    if (typed_stride_ != 0 && !same_chunk_) {
+      assert(world_size_ <= 8);  // intra-node WHOLEMEMORY_MT_CHUNKED
+      for (int i = 0; i < world_size_ + 1; i++) {
+        assert(gref.rank_memory_offsets[i] % sizeof(DataTypeT) == 0);
+        typed_rank_mem_offsets_[i] = gref.rank_memory_offsets[i] / sizeof(DataTypeT);
+      }
     }
   }
   __device__ device_reference() = delete;
 
-  __device__ __forceinline__ size_t copy_offsets_to_shmem(char* shmem, size_t maxsize)
-  {
-    if (typed_stride_ == 0 || same_chunk_) return 0;
-    size_t used_shmem_size = (world_size_ + 1) * sizeof(size_t);
-    if (used_shmem_size > maxsize) return 0;
-    size_t* shmem_offsets = reinterpret_cast<size_t*>(shmem);
-    for (int i = threadIdx.x; i <= world_size_; i += blockDim.x) {
-      shmem_offsets[i] = rank_memory_offsets_[i];
-    }
-    __syncthreads();
-    rank_memory_offsets_ = shmem_offsets;
-    size_t aligned_used_shmem_size = ((used_shmem_size - 1) / 128 + 1) * 128;
-    return aligned_used_shmem_size;
-  }
-
   __device__ __forceinline__ DataTypeT& operator[](size_t index)
   {
     if (typed_stride_ == 0) { return pointer_[index]; }
@@ -68,47 +49,25 @@ class device_reference {
       return static_cast<DataTypeT**>(
         static_cast<void*>(pointer_))[rank][index - rank * typed_stride_];
     } else {
-      size_t rank = 0;
-      size_t offset = index * sizeof(DataTypeT);
-      if (offset >= cache_offset_ && offset < cache_offset_ + cache_size_) {
-        rank = cache_rank_;
-      } else {
-        int estimated_rank = max(world_size_ - 1, int(offset / estimated_stride_));
-        if (rank_memory_offsets_[estimated_rank] > offset) {
-          for (int i = estimated_rank - 1; i >= 0; i--) {
-            if (rank_memory_offsets_[i] <= offset) {
-              rank = i;
-              break;
-            }
-          }
-        } else {
-          for (int i = estimated_rank + 1; i <= world_size_; i++) {
-            if (rank_memory_offsets_[i] > offset) {
-              rank = i - 1;
-              break;
-            }
-          }
+      size_t rank = 0;
+      for (int i = 1; i < world_size_ + 1; i++) {
+        if (index < typed_rank_mem_offsets_[i]) {
+          rank = i - 1;
+          break;
         }
-        cache_rank_   = rank;
-        cache_offset_ = rank_memory_offsets_[rank];
-        cache_size_   = rank_memory_offsets_[rank + 1] - rank_memory_offsets_[rank];
       }
       return static_cast<DataTypeT**>(
-        static_cast<void*>(pointer_))[rank][index - cache_offset_ / sizeof(DataTypeT)];
+        static_cast<void*>(pointer_))[rank][index - typed_rank_mem_offsets_[rank]];
     }
   }
 
  private:
   DataTypeT* pointer_;
-  size_t* rank_memory_offsets_;
   int world_size_;
   size_t typed_stride_;
-  size_t estimated_stride_;
   bool same_chunk_;
-  int cache_rank_;
-  size_t cache_offset_;
-  size_t cache_size_;
+  size_t typed_rank_mem_offsets_[8 + 1];
 };
 
 }  // namespace wholememory
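
With this change, device_reference no longer caches byte offsets or stages them in shared memory; for unevenly chunked (WHOLEMEMORY_MT_CHUNKED) memory it keeps at most 8 + 1 element offsets per thread and resolves the owning rank with a short linear scan. The standalone sketch below is illustrative only and not part of the patch; owning_rank and the sample offsets are hypothetical names and values showing the same prefix-offset lookup on the host:

    // Illustrative sketch only, not part of the patch: the prefix-offset lookup
    // that device_reference::operator[] now performs, as a plain host function.
    #include <cassert>
    #include <cstddef>

    // offsets[0..world_size] are element offsets (not bytes); offsets[0] == 0 and
    // offsets[world_size] is the total element count. Returns the rank owning `index`.
    inline int owning_rank(const size_t* offsets, int world_size, size_t index)
    {
      int rank = 0;
      for (int i = 1; i < world_size + 1; i++) {
        if (index < offsets[i]) {
          rank = i - 1;
          break;
        }
      }
      return rank;
    }

    int main()
    {
      // Four ranks with uneven chunk sizes of 10, 12, 9 and 11 elements.
      const size_t offsets[] = {0, 10, 22, 31, 42};
      assert(owning_rank(offsets, 4, 0) == 0);
      assert(owning_rank(offsets, 4, 10) == 1);
      assert(owning_rank(offsets, 4, 30) == 2);
      assert(owning_rank(offsets, 4, 41) == 3);
      return 0;
    }

Because world_size_ is asserted to be at most 8 in the chunked intra-node case, the scan touches at most nine values held per thread, which is what makes dropping the shared-memory staging (the removed copy_offsets_to_shmem) viable.
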
diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu
index f5b25390b..6bd6b6c44 100644
--- a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu
+++ b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu
@@ -29,10 +29,10 @@ namespace wholememory_ops {
 
 template <typename IndexT>
-__device__ int dest_rank(IndexT entry_idx,
-                         size_t total_entry_count,
-                         const size_t* embedding_entry_offsets,
-                         int world_size)
+__device__ __forceinline__ int dest_rank(IndexT entry_idx,
+                                         size_t total_entry_count,
+                                         const size_t* embedding_entry_offsets,
+                                         int world_size)
 {
   size_t estimated_entry_per_rank = total_entry_count / world_size;
   int estimated_rank = max(world_size - 1, int(entry_idx / estimated_entry_per_rank));
@@ -60,13 +60,18 @@ __global__ void bucket_ids_for_ranks_kernel(const IndexT* indices,
   for (int idx = threadIdx.x; idx < world_size; idx += blockDim.x) {
     rank_count_shared[idx] = 0;
   }
+  size_t* embedding_entry_offsets_shared =
+    reinterpret_cast<size_t*>(shmem + sizeof(size_t) * world_size);
+  for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) {
+    embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx];
+  }
   __syncthreads();
-  size_t total_entry_count = embedding_entry_offsets[world_size];
+  size_t total_entry_count = embedding_entry_offsets_shared[world_size];
   for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count;
        idx += blockDim.x * gridDim.x) {
     IndexT node_idx = indices[idx];
     if (node_idx < 0) continue;
-    int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets, world_size);
+    int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets_shared, world_size);
     assert(rank >= 0 && rank < world_size);
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
     atomicAdd_block(&rank_count_shared[rank], 1);
@@ -95,7 +100,10 @@ void bucket_ids_for_ranks_temp_fn(void* indices,
   block_count = std::min(block_count, sm_count * 4);
   IndexT* indices_ptr = static_cast<IndexT*>(indices);
   indices_ptr += indice_desc.storage_offset;
-  bucket_ids_for_ranks_kernel<<<block_count, BLOCK_SIZE, sizeof(size_t) * world_size, stream>>>(
+  bucket_ids_for_ranks_kernel<<<block_count,
+                                BLOCK_SIZE,
+                                sizeof(size_t) * (world_size * 2 + 1),
+                                stream>>>(
     indices_ptr, indice_desc.size, dev_rank_id_count_ptr, embedding_entry_offsets, world_size);
 }
 
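
The kernel change above follows the usual pattern of staging a small lookup table in dynamic shared memory once per block before the grid-stride loop, so the launch in bucket_ids_for_ranks_temp_fn has to reserve room for both the per-rank counters and the world_size + 1 offsets. The self-contained CUDA sketch below is not the WholeMemory kernel; all names (count_per_rank, off_sh, the sample data) are hypothetical, and it only demonstrates the copy, the barrier, and the matching dynamic shared-memory size:

    // Hypothetical, self-contained sketch of the shared-memory staging pattern.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void count_per_rank(const long* ids, int n, const size_t* offsets,
                                   int world_size, unsigned long long* counts)
    {
      extern __shared__ char shmem[];
      // Layout: world_size counters first, then the world_size + 1 offsets copy.
      unsigned long long* cnt_sh = reinterpret_cast<unsigned long long*>(shmem);
      size_t* off_sh             = reinterpret_cast<size_t*>(cnt_sh + world_size);
      for (int i = threadIdx.x; i < world_size; i += blockDim.x) cnt_sh[i] = 0;
      for (int i = threadIdx.x; i < world_size + 1; i += blockDim.x) off_sh[i] = offsets[i];
      __syncthreads();
      for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += blockDim.x * gridDim.x) {
        long id = ids[i];
        if (id < 0) continue;  // negative ids are skipped, as in the kernel above
        int rank = 0;
        for (int r = 1; r < world_size + 1; r++) {
          if (static_cast<size_t>(id) < off_sh[r]) { rank = r - 1; break; }
        }
        atomicAdd(&cnt_sh[rank], 1ULL);
      }
      __syncthreads();
      for (int i = threadIdx.x; i < world_size; i += blockDim.x) atomicAdd(&counts[i], cnt_sh[i]);
    }

    int main()
    {
      const int world_size = 4, n = 8;
      long* ids; size_t* offsets; unsigned long long* counts;
      cudaMallocManaged(&ids, n * sizeof(long));
      cudaMallocManaged(&offsets, (world_size + 1) * sizeof(size_t));
      cudaMallocManaged(&counts, world_size * sizeof(unsigned long long));
      const long sample_ids[n] = {0, 5, 10, 21, 22, 30, 41, -1};
      const size_t sample_offsets[world_size + 1] = {0, 10, 22, 31, 42};
      for (int i = 0; i < n; i++) ids[i] = sample_ids[i];
      for (int i = 0; i <= world_size; i++) offsets[i] = sample_offsets[i];
      for (int i = 0; i < world_size; i++) counts[i] = 0;
      // Dynamic shared memory must cover the counters plus the offsets copy.
      size_t shmem_bytes = world_size * sizeof(unsigned long long) + (world_size + 1) * sizeof(size_t);
      count_per_rank<<<2, 128, shmem_bytes>>>(ids, n, offsets, world_size, counts);
      cudaDeviceSynchronize();
      for (int i = 0; i < world_size; i++) printf("rank %d: %llu ids\n", i, counts[i]);
      return 0;
    }

Staging matters because dest_rank probes the offsets once per index; without the shared copy, the small table would be re-read from global memory for every element of indices.
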
diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
index fd6d0f8d5..140b257f8 100644
--- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
+++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
@@ -272,13 +272,12 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
   int64_t output_stride = output_desc.stride;
 
   wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
-  size_t used_shm_size = embedding_dev_ref.copy_offsets_to_shmem(shm_in_char, shm_max_size);
 
   typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
   typed_data_vector<OutputT, ALIGNMENT> outputs;
 
-  int shm_size = (shm_max_size - used_shm_size) / sizeof(OutputT);
-  OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char + used_shm_size);
+  int shm_size = shm_max_size / sizeof(OutputT);
+  OutputT* all_sh = reinterpret_cast<OutputT*>(shm_in_char);
   OutputT* my_shared;
   bool use_shm = true;
   if (shm_size / (blockDim.x / 32) < output_desc.sizes[1]) {  //
@@ -346,10 +345,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
 
   int lane_id_in_sub_warp = subwarp.thread_rank();
-  constexpr size_t shm_max_size = 1024 * sizeof(size_t);
-  __shared__ char shmem[shm_max_size];
   wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
-  embedding_dev_ref.copy_offsets_to_shmem(shmem, shm_max_size);
 
   int embedding_size       = embedding_desc.sizes[1];
   int64_t embedding_stride = embedding_desc.stride;
@@ -365,11 +361,10 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
     if (embedding_table_idx < 0) continue;
     int64_t embedding_offset =
       embedding_desc.storage_offset + embedding_table_idx * embedding_stride;
-
+    EmbeddingT* emb_ptr = &embedding_dev_ref[embedding_offset];
     for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size;
          emb_idx += ALIGNMENT * SUB_WARP_SIZE) {
-      mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings,
-                                               &embedding_dev_ref[embedding_offset + emb_idx]);
+      mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings, &emb_ptr[emb_idx]);
 #pragma unroll
       for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
         typed_data_vector_at(outputs, sub_idx) =
@@ -542,10 +537,9 @@ __global__ void scatter_func_kernel(const InputT* input,
   int async_copy_align = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT);
 
   wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);
-  size_t used_shm_size = embedding_dev_ref.copy_offsets_to_shmem(shm_in_char, shm_max_size);
 
-  int shm_size = (shm_max_size - used_shm_size) / sizeof(InputT);
-  InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char + used_shm_size);
+  int shm_size = shm_max_size / sizeof(InputT);
+  InputT* all_sh = reinterpret_cast<InputT*>(shm_in_char);
   InputT* my_shared;
   int batch_size = (shm_size / (blockDim.x / 32) - async_copy_align) /
                    input_stride;  // indices batch size in lines
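
With the offsets no longer carved out of the front of dynamic shared memory, gather_func_kernel and scatter_func_kernel hand the whole shm_max_size back to the staging buffer, and the per-warp batch size comes straight from that budget. Below is a small host-side sketch of the arithmetic; it is illustrative only, the function and parameter names are made up, and InputT is represented by its element size in bytes:

    // Illustrative host-side sketch of the per-warp shared-memory budgeting.
    #include <cstdio>

    int batch_size_in_lines(int shm_max_bytes, int block_dim_x, int elem_bytes,
                            int input_stride, int async_copy_align)
    {
      int shm_elems      = shm_max_bytes / elem_bytes;            // whole buffer, in elements
      int warps_in_block = block_dim_x / 32;
      int per_warp_elems = shm_elems / warps_in_block;            // one slice per warp
      return (per_warp_elems - async_copy_align) / input_stride;  // rows that fit in a slice
    }

    int main()
    {
      // e.g. 24 KiB of dynamic shared memory, 256 threads, 4-byte elements, stride 128
      printf("%d rows per warp batch\n", batch_size_in_lines(24 * 1024, 256, 4, 128, 1));
      return 0;
    }

The subtraction of async_copy_align mirrors the kernel's headroom for aligned asynchronous copies; the rest is the per-warp slice divided by the row stride.
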