Skip to content

Commit

Permalink
Add gpu implementation of shuffle_batch_op (PaddlePaddle#33938)
Browse files Browse the repository at this point in the history
* add gpu implementation of shuffle batch
test=develop

* add thrust cuda patches
test=develop

* fix macro guard

* fix shuffle batch compile on windows/hip

* fix hip compilation error

* refine CMakeLists.txt

* fix windows compile error

* try to fix windows CI compilation error

* fix windows compilation again

* fix shuffle_batch op test on Windows
  • Loading branch information
sneaxiy committed Jul 6, 2021
1 parent 5085c44 commit c6b6ba1
Show file tree
Hide file tree
Showing 10 changed files with 836 additions and 19 deletions.
1 change: 1 addition & 0 deletions cmake/cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,4 @@ endif()
mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)

include(thrust)
2 changes: 2 additions & 0 deletions cmake/hip.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,5 @@ message(STATUS "HIP library name: ${hip_library_name}")
# set HIP link libs
find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib)
message(STATUS "ROCM_HIPRTC_LIB: ${ROCM_HIPRTC_LIB}")

include(thrust)
24 changes: 24 additions & 0 deletions cmake/thrust.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
function(add_thrust_patches_if_necessary)
set(thrust_detect_file ${PROJECT_BINARY_DIR}/detect_thrust.cu)
file(WRITE ${thrust_detect_file} ""
"#include \"thrust/version.h\"\n"
"#include \"thrust/shuffle.h\"\n"
"#include \"stdio.h\"\n"
"int main() {\n"
" int version = THRUST_VERSION;\n"
" printf(\"%d\", version);\n"
" return 0;\n"
"}\n")

execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}"
"--run" "${thrust_detect_file}"
WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
RESULT_VARIABLE nvcc_res ERROR_QUIET)
if(NOT nvcc_res EQUAL 0)
set(thrust_patches "${PADDLE_SOURCE_DIR}/patches/thrust")
message(STATUS "Add thrust patches: ${thrust_patches}")
include_directories(${thrust_patches})
endif()
endfunction()

add_thrust_patches_if_necessary()
10 changes: 10 additions & 0 deletions paddle/fluid/operators/shuffle_batch_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ class ShuffleBatchOp : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}

framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "Seed") {
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};

class ShuffleBatchOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
159 changes: 159 additions & 0 deletions paddle/fluid/operators/shuffle_batch_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#ifndef _MSC_VER
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#endif

#include "paddle/fluid/operators/shuffle_batch_op.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename T, bool kIsForward>
struct ReorderFunctor {
ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride)
: x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {}

HOSTDEVICE void operator()(int64_t idx) {
auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_;
if (kIsForward) {
y_[idx] = x_[reorder_idx];
} else {
y_[reorder_idx] = x_[idx];
}
}

private:
const T *x_;
const int64_t *shuffle_idx_;
T *y_;
int64_t stride_;
};

template <typename T>
class ShuffleBatchCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
PADDLE_THROW(platform::errors::Unimplemented(
"GPU shuffle_batch is not supported on Windows yet"));
#else
auto *x = ctx.Input<framework::Tensor>("X");
auto *seed = ctx.Input<framework::Tensor>("Seed");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *shuffleidx = ctx.Output<framework::Tensor>("ShuffleIdx");
auto *seed_out = ctx.Output<framework::Tensor>("SeedOut");

int64_t x_embed_size = x->dims()[x->dims().size() - 1];
int64_t elem_size = 1;
for (int i = 0; i < x->dims().size() - 1; i++) {
elem_size *= x->dims()[i];
}
shuffleidx->Resize(framework::make_ddim({elem_size}));

int64_t seed_int = 0;
if (seed->IsInitialized()) {
const auto &seed_place = seed->place();
if (platform::is_gpu_place(seed_place)) {
// NOTE: We have overwritten GetKernelTypeForVar, so seed_place would
// not be CUDAPlace in practice. This case would only happen in Python
// op_test framework.
framework::Tensor tmp_tensor;
framework::TensorCopySync(*seed, platform::CPUPlace(), &tmp_tensor);
seed_int = *(tmp_tensor.data<int64_t>());
} else {
seed_int = *(seed->data<int64_t>());
}
} else {
seed_int = ctx.Attr<int>("startup_seed");
}

auto *shuffleidx_data = shuffleidx->mutable_data<int64_t>(ctx.GetPlace());

auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
#ifdef PADDLE_WITH_CUDA
const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
thrust::random::default_random_engine engine(seed_int);
thrust::counting_iterator<int64_t> cnt_iter(0);
thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + elem_size,
thrust::device_pointer_cast(shuffleidx_data), engine);
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
auto *x_data = x->data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
ReorderFunctor<T, true> functor(x_data, shuffleidx_data, out_data,
x_embed_size);
platform::ForRange<platform::CUDADeviceContext> for_range(
dev_ctx, elem_size * x_embed_size);
for_range(functor);

auto *seed_out_data = seed_out->mutable_data<int64_t>(
framework::make_ddim({1}), platform::CPUPlace());
*seed_out_data = engine();
#endif
}
};

template <typename T>
class ShuffleBatchGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
PADDLE_THROW(platform::errors::Unimplemented(
"GPU shuffle_batch_grad is not supported on Windows yet"));
#else
const auto *out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const auto *shuffleidx = ctx.Input<framework::Tensor>("ShuffleIdx");
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

const auto *out_grad_data = out_grad->data<T>();
const auto *shuffleidx_data = shuffleidx->data<int64_t>();
auto *x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1];
ReorderFunctor<T, false> functor(out_grad_data, shuffleidx_data,
x_grad_data, x_embed_size);
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
x_grad->numel());
for_range(functor);
#endif
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(shuffle_batch, ops::ShuffleBatchCUDAKernel<float>,
ops::ShuffleBatchCUDAKernel<double>,
ops::ShuffleBatchCUDAKernel<int32_t>,
ops::ShuffleBatchCUDAKernel<int64_t>);

REGISTER_OP_CUDA_KERNEL(shuffle_batch_grad,
ops::ShuffleBatchGradCUDAKernel<float>,
ops::ShuffleBatchGradCUDAKernel<double>,
ops::ShuffleBatchGradCUDAKernel<int32_t>,
ops::ShuffleBatchGradCUDAKernel<int64_t>);
#endif
85 changes: 85 additions & 0 deletions patches/thrust/thrust/detail/shuffle.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2008-2020 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*! \file shuffle.inl
* \brief Inline file for shuffle.h.
*/

#include <thrust/detail/config.h>
#include <thrust/detail/cpp11_required.h>

#if THRUST_CPP_DIALECT >= 2011

#include <thrust/iterator/iterator_traits.h>
#include <thrust/shuffle.h>
#include <thrust/system/detail/generic/select_system.h>
#include <thrust/system/detail/generic/shuffle.h>

namespace thrust {

__thrust_exec_check_disable__
template <typename DerivedPolicy, typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(
const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
RandomIterator first, RandomIterator last, URBG&& g) {
using thrust::system::detail::generic::shuffle;
return shuffle(
thrust::detail::derived_cast(thrust::detail::strip_const(exec)),
first, last, g);
}

template <typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(RandomIterator first, RandomIterator last,
URBG&& g) {
using thrust::system::detail::generic::select_system;

typedef typename thrust::iterator_system<RandomIterator>::type System;
System system;

return thrust::shuffle(select_system(system), first, last, g);
}

__thrust_exec_check_disable__
template <typename DerivedPolicy, typename RandomIterator,
typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(
const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
RandomIterator first, RandomIterator last, OutputIterator result,
URBG&& g) {
using thrust::system::detail::generic::shuffle_copy;
return shuffle_copy(
thrust::detail::derived_cast(thrust::detail::strip_const(exec)),
first, last, result, g);
}

template <typename RandomIterator, typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(RandomIterator first, RandomIterator last,
OutputIterator result, URBG&& g) {
using thrust::system::detail::generic::select_system;

typedef typename thrust::iterator_system<RandomIterator>::type System1;
typedef typename thrust::iterator_system<OutputIterator>::type System2;

System1 system1;
System2 system2;

return thrust::shuffle_copy(select_system(system1, system2), first, last,
result, g);
}

} // namespace thrust

#endif
Loading

0 comments on commit c6b6ba1

Please sign in to comment.