Skip to content

Commit

Permalink
Add fp16 backward support (PaddlePaddle#14202)
Browse files Browse the repository at this point in the history
* add fp16 backward support
test=develop

* add sum_op fp16 test

* disable test_dist_save_load
test=develop

* add check_grad for sum

* add unit test for softmax_grad fp16
test=develop

* add scale_op unit test

* add mul_grad_op unit test for fp16

* add cross_entropy_grad and mean_grad unit test for fp16
test=develop

* fix cross_entropy unit test

* add pool2d fp16 unit test

* refine conv2d fp16 unit test
test=develop

* refine activation unit test
test=develop

* fix ci
test=develop

* follow zhihong's comment, copy from PaddlePaddle#12796
test=develop
  • Loading branch information
chengduo committed Nov 7, 2018
1 parent 0953cd3 commit a9b5d42
Show file tree
Hide file tree
Showing 31 changed files with 767 additions and 961 deletions.
4 changes: 3 additions & 1 deletion paddle/fluid/operators/activation_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace plat = paddle::platform;
act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>);
ops::grad_functor<double>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);

FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
5 changes: 2 additions & 3 deletions paddle/fluid/operators/activation_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
const Out out_conj = Eigen::numext::conj(out);
dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
dx.device(d) = static_cast<T>(0.5) * dout / out;
}
};

Expand Down Expand Up @@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor - static_cast<T>(1)));
x.pow(static_cast<T>(factor) - static_cast<T>(1));
}
};

Expand Down
21 changes: 12 additions & 9 deletions paddle/fluid/operators/batch_norm_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if ((N * H * W * D) == 1) {
Expand Down Expand Up @@ -272,19 +272,21 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>

const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data = saved_mean->template data<T>();
const void *saved_var_data = saved_var->template data<T>();
const void *saved_mean_data =
saved_mean->template data<BatchNormParamType<T>>();
const void *saved_var_data =
saved_var->template data<BatchNormParamType<T>>();

CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(),
d_scale->template mutable_data<T>(ctx.GetPlace()),
d_bias->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean_data, saved_var_data));
scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));

// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
Expand All @@ -304,4 +306,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
ops::BatchNormGradKernel<plat::CUDADeviceContext, double>);
ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
5 changes: 4 additions & 1 deletion paddle/fluid/operators/conv_cudnn_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
// Currently tensor core is only enabled using this algo
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
VLOG(5) << "use cudnn_tensor_op_math";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
}
#endif

Expand Down Expand Up @@ -361,7 +363,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);
paddle::operators::CUDNNConvGradOpKernel<double>,
paddle::operators::CUDNNConvGradOpKernel<plat::float16>);

REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
Expand Down
13 changes: 9 additions & 4 deletions paddle/fluid/operators/cross_entropy_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cross_entropy_op.h"
#include "paddle/fluid/platform/float16.h"

namespace plat = paddle::platform;
namespace ops = paddle::operators;
using CUDACtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(cross_entropy,
ops::CrossEntropyOpKernel<CUDACtx, float>,
ops::CrossEntropyOpKernel<CUDACtx, double>);
REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel<CUDACtx, double>);
ops::CrossEntropyOpKernel<CUDACtx, double>,
ops::CrossEntropyOpKernel<CUDACtx, plat::float16>);

REGISTER_OP_CUDA_KERNEL(
cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
3 changes: 2 additions & 1 deletion paddle/fluid/operators/elementwise_add_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val = 0;
T val(0);

do {
int x_offset = i * w + j;
Expand Down Expand Up @@ -433,7 +433,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
int tid = threadIdx.x;
int j = blockIdx.x;

T val = 0;
T val(0);
int ttid = tid;

while (true) {
Expand Down
22 changes: 16 additions & 6 deletions paddle/fluid/operators/math/cross_entropy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ namespace operators {
namespace math {

namespace {

// real_log: type-dispatched natural logarithm for device code. Overload
// resolution on the argument type selects the matching math routine, so the
// templated kernels in this file can call real_log(x) for any supported T.

// float overload: forwards to the single-precision logf.
__device__ __forceinline__ float real_log(float x) { return logf(x); }

// double overload: forwards to the double-precision log.
__device__ __forceinline__ double real_log(double x) { return log(x); }

// float16 overload: converts the value to half, applies hlog, and converts
// the result back to platform::float16.
__device__ __forceinline__ platform::float16 real_log(
const platform::float16& val) {
return static_cast<platform::float16>(hlog(static_cast<half>(val)));
}

template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
const int N, const int D,
Expand All @@ -29,21 +39,21 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D || label[i] == ignore_index);
Y[i] = ignore_index == label[i]
? 0
: -math::TolerableValue<T>()(log(X[i * D + label[i]]));
? static_cast<T>(0)
: -math::TolerableValue<T>()(real_log(X[i * D + label[i]]));
}
}

template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
T val = 0;
T val(0);

int idx = blockIdx.x * class_num + tid;
int end = blockIdx.x * class_num + class_num;
for (; idx < end; idx += blockDim.x) {
val += math::TolerableValue<T>()(std::log(X[idx])) * label[idx];
val += math::TolerableValue<T>()(real_log(X[idx])) * label[idx];
}

val = paddle::platform::reduceSum(val, tid, blockDim.x);
Expand All @@ -53,8 +63,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
}
} // namespace

using Tensor = framework::Tensor;

template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
public:
Expand Down Expand Up @@ -89,6 +97,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {

template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
template class CrossEntropyFunctor<platform::CUDADeviceContext,
platform::float16>;
} // namespace math
} // namespace operators
} // namespace paddle
21 changes: 21 additions & 0 deletions paddle/fluid/operators/math/cross_entropy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <limits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/hostdevice.h"
Expand All @@ -33,6 +34,26 @@ struct TolerableValue {
}
};

// NOTE(dzh): float16 value clipping behaves differently.
// 1. Our ValueClipping has a hard-coded threshold of 1e20
// for float numbers. 1e20 would result in overflow in float16.
// 2. float16 should expose the real numeric overflow to Python,
// because mixed-precision training depends on the inf/nan values to
// determine whether the scale value should be adjusted.
// Also, in the standard implementation of cross entropy, other
// frameworks do not have the ValueClipping.
// Specialization of TolerableValue for float16: finite inputs pass through
// unchanged, non-finite inputs are clamped to the finite float16 range.
template <>
struct TolerableValue<platform::float16> {
  // Returns x when it is finite, the largest finite float16 for a positive
  // non-finite input, and the most negative finite float16 otherwise.
  HOSTDEVICE platform::float16 operator()(const platform::float16& x) const {
    if (platform::isfinite(x)) {
      return x;
    }
    if (x > static_cast<platform::float16>(0)) {
      return std::numeric_limits<platform::float16>::max();
    }
    // std::numeric_limits<T>::min() is the smallest *positive* normal value
    // for floating-point types, so the negative bound must come from
    // lowest(); using min() would turn -inf into a tiny positive number.
    // NOTE(review): NaN presumably fails both tests above and also lands
    // here; confirm whether NaN should be propagated unchanged instead.
    return std::numeric_limits<platform::float16>::lowest();
  }
};

template <typename DeviceContext, typename T>
class CrossEntropyFunctor {
public:
Expand Down
15 changes: 12 additions & 3 deletions paddle/fluid/operators/math/selected_rows_functor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -118,7 +119,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
auto* out_data = output->data<T>();

SetConstant<platform::CUDADeviceContext, T> functor;
functor(context, output, 0.0);
functor(context, output, static_cast<T>(0));

const int block_size = 256;
dim3 threads(block_size, 1);
Expand All @@ -136,6 +137,9 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {

template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
platform::float16>;

template <typename T>
struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
Expand Down Expand Up @@ -175,6 +179,8 @@ template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddTo<platform::CUDADeviceContext,
platform::float16>;

namespace {
template <typename T, int block_size>
Expand Down Expand Up @@ -227,6 +233,8 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
platform::float16>;

namespace scatter {

Expand Down Expand Up @@ -287,7 +295,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
context.GetPlace());

math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
constant_functor(context, out.mutable_value(), static_cast<T>(0));

auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
Expand Down Expand Up @@ -347,7 +355,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
context.GetPlace());

math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
constant_functor(context, out.mutable_value(), static_cast<T>(0));

auto* out_data = out.mutable_value()->data<T>();

Expand All @@ -374,6 +382,7 @@ template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;

template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/operators/math/softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,15 @@ template class SoftmaxCUDNNFunctor<float>;
template class SoftmaxCUDNNFunctor<double>;
template class SoftmaxGradCUDNNFunctor<float>;
template class SoftmaxGradCUDNNFunctor<double>;
template class SoftmaxGradCUDNNFunctor<platform::float16>;

template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext,
platform::float16>;

} // namespace math
} // namespace operators
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/operators/mean_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ limitations under the License. */
#define EIGEN_USE_GPU

#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
mean, ops::MeanKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanKernel<paddle::platform::CUDADeviceContext, double>);
ops::MeanKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
mean_grad, ops::MeanGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
3 changes: 1 addition & 2 deletions paddle/fluid/operators/mean_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ class MeanGradKernel : public framework::OpKernel<T> {
IG->mutable_data<T>(context.GetPlace());

T ig_size = static_cast<T>(IG->numel());
Eigen::DSizes<int, 1> bcast(ig_size);

Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
EigenVector<T>::Flatten(*IG).device(
*context.template device_context<DeviceContext>().eigen_device()) =
(EigenVector<T>::From(*OG) / ig_size).broadcast(bcast);
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/operators/mul_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
ops::MulKernel<plat::CUDADeviceContext, double>,
ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(mul_grad,
ops::MulGradKernel<plat::CUDADeviceContext, float>,
ops::MulGradKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
mul_grad, ops::MulGradKernel<plat::CUDADeviceContext, float>,
ops::MulGradKernel<plat::CUDADeviceContext, double>,
ops::MulGradKernel<plat::CUDADeviceContext, plat::float16>);
3 changes: 2 additions & 1 deletion paddle/fluid/operators/pool_cudnn_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
ops::PoolCUDNNOpKernel<plat::float16>);
REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
ops::PoolCUDNNGradOpKernel<float>,
ops::PoolCUDNNGradOpKernel<double>);
ops::PoolCUDNNGradOpKernel<double>,
ops::PoolCUDNNGradOpKernel<plat::float16>);

REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
ops::PoolCUDNNOpKernel<float>,
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/operators/scale_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/scale_op.h"
#include "paddle/fluid/platform/float16.h"
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
scale,
paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext, double>,
paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext, int>,
paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext,
int64_t>);
int64_t>,
paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
3 changes: 2 additions & 1 deletion paddle/fluid/operators/softmax_cudnn_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,5 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>,
ops::SoftmaxGradCUDNNKernel<double>);
ops::SoftmaxGradCUDNNKernel<double>,
ops::SoftmaxGradCUDNNKernel<plat::float16>);
Loading

0 comments on commit a9b5d42

Please sign in to comment.