Skip to content

Commit

Permalink
Add device guard (fix multi-GPU) (#10)
Browse files: browse the repository at this point in the history
  • Loading branch information
casper-hansen authored Feb 16, 2024
1 parent bad253e commit 2cae290
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions awq_ext/vllm/moe_alig_block.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
Expand Down / Expand Up
@@ -75,6 +76,10 @@ void moe_alig_block_size(
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const at::cuda::OptionalCUDAGuard device_guard_topk_ids(device_of(topk_ids));
const at::cuda::OptionalCUDAGuard device_guard_sorted(device_of(sorted_token_ids));
const at::cuda::OptionalCUDAGuard device_guard_experts(device_of(experts_ids));
const at::cuda::OptionalCUDAGuard device_guard_num_tokens(device_of(num_tokens_post_pad));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
assert(num_experts <= NUM_MAX_EXPERTS);
VLLM_DISPATCH_INTEGRAL_TYPES(
Expand Down

0 comments on commit 2cae290

Please sign in to comment.