diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu index 336c1c40832b9..f3fe32e10a52b 100644 --- a/paddle/fluid/operators/trace_op.cu +++ b/paddle/fluid/operators/trace_op.cu @@ -14,6 +14,7 @@ #include #include +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/trace_op.h" @@ -50,6 +51,9 @@ class TraceCUDAKernel : public framework::OpKernel { TensorReduce( diag, out, reduce_dims, static_cast(0), cub::Sum(), IdentityFunctor(), stream); + } else { + math::SetConstant functor; + functor(context.device_context(), out, static_cast(0)); } } }; diff --git a/paddle/fluid/operators/trace_op.h b/paddle/fluid/operators/trace_op.h index b7a6e559ed4ef..ca9439cbed97d 100644 --- a/paddle/fluid/operators/trace_op.h +++ b/paddle/fluid/operators/trace_op.h @@ -179,7 +179,7 @@ class TraceKernel : public framework::OpKernel { auto output_dims = out->dims(); - out->mutable_data(context.GetPlace()); + T* out_data = out->mutable_data(context.GetPlace()); const framework::Tensor diag = Diagonal(context, input, offset, dim1, dim2); @@ -191,6 +191,8 @@ class TraceKernel : public framework::OpKernel { auto reduce_dim = Eigen::array({1}); output.device(place) = x.sum(reduce_dim); out->Resize(output_dims); + } else { + std::fill(out_data, out_data + out->numel(), static_cast(0)); } } };