From 2cae2907b6c07f83aa6a17ca5b475df574896e7b Mon Sep 17 00:00:00 2001 From: Casper Date: Fri, 16 Feb 2024 16:12:54 +0100 Subject: [PATCH] Add device guard (fix multi-GPU) (#10) --- awq_ext/vllm/moe_alig_block.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/awq_ext/vllm/moe_alig_block.cu b/awq_ext/vllm/moe_alig_block.cu index 63578e5..811cf63 100644 --- a/awq_ext/vllm/moe_alig_block.cu +++ b/awq_ext/vllm/moe_alig_block.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -75,6 +76,10 @@ void moe_alig_block_size( torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { + const at::cuda::OptionalCUDAGuard device_guard_topk_ids(device_of(topk_ids)); + const at::cuda::OptionalCUDAGuard device_guard_sorted(device_of(sorted_token_ids)); + const at::cuda::OptionalCUDAGuard device_guard_experts(device_of(experts_ids)); + const at::cuda::OptionalCUDAGuard device_guard_num_tokens(device_of(num_tokens_post_pad)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); assert(num_experts <= NUM_MAX_EXPERTS); VLLM_DISPATCH_INTEGRAL_TYPES(