"cherry picked cpp tests" (PaddlePaddle#12182)
* "cherry picked cpp tests"

* "cherry picked"

* "cherry picked tests"

* "merge develop branch"
dzhwinter authored Aug 1, 2018
1 parent 8037bdf commit cd81e53
Showing 9 changed files with 163 additions and 21 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/CMakeLists.txt
@@ -7,6 +7,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
cc_test(data_type_test SRCS data_type_test.cc DEPS data_type place tensor)
if(WITH_GPU)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context)
else()
4 changes: 3 additions & 1 deletion paddle/fluid/framework/data_type.cc
@@ -17,6 +17,8 @@
#include <string>
#include <unordered_map>

using float16 = paddle::platform::float16;

namespace paddle {
namespace framework {

@@ -53,7 +55,7 @@ static DataTypeMap* InitDataTypeMap() {
RegisterType<cc_type>(retv, proto_type, #cc_type)

// NOTE: Add your customize type here.
RegType(platform::float16, proto::VarType::FP16);
RegType(float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
RegType(int, proto::VarType::INT32);
40 changes: 40 additions & 0 deletions paddle/fluid/framework/data_type_test.cc
@@ -0,0 +1,40 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/data_type.h"

#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"

TEST(DataType, float16) {
using paddle::framework::Tensor;
using paddle::platform::CPUPlace;
using paddle::platform::float16;
namespace f = paddle::framework;
f::proto::VarType::Type dtype = f::proto::VarType::FP16;

Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, f::ToTypeIndex(dtype));

// test fp16 tensor
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));

// test fp16 size
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);

// test debug info
std::string type = "float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
7 changes: 7 additions & 0 deletions paddle/fluid/framework/op_kernel_type_test.cc
@@ -29,6 +29,13 @@ TEST(OpKernelType, ToString) {
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type),
"data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type["
"CUDNN]");

using CUDAPlace = paddle::platform::CUDAPlace;
OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN);
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
"data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_"
"type[CUDNN]");
}

TEST(OpKernelType, Hash) {
17 changes: 17 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -69,6 +69,21 @@ static DDim GetDims(const Scope& scope, const std::string& name,
}
}

static std::string GetDtype(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return "";
}
if (var->IsType<LoDTensor>()) {
return DataTypeToString(ToDataType(var->Get<LoDTensor>().type()));
} else if (var->IsType<SelectedRows>()) {
return DataTypeToString(
ToDataType(var->Get<SelectedRows>().value().type()));
} else {
return "";
}
}

static int GetRowSize(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
@@ -172,6 +187,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
std::string dtype = GetDtype(*scope, input.second[i]);
ss << ":" << dtype;
ss << "[" << GetDims(*scope, input.second[i], true) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")";
}
15 changes: 15 additions & 0 deletions paddle/fluid/framework/tensor_test.cc
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/tensor.h"
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/platform/float16.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
@@ -213,3 +214,17 @@ TEST(Tensor, Layout) {
src.set_layout(framework::DataLayout::kAnyLayout);
ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
}

TEST(Tensor, FP16) {
using platform::float16;
framework::Tensor src;
float16* src_ptr = src.mutable_data<float16>({2, 3}, platform::CPUPlace());
for (int i = 0; i < 2 * 3; ++i) {
src_ptr[i] = static_cast<float16>(i);
}
EXPECT_EQ(src.memory_size(), 2 * 3 * sizeof(float16));
// EXPECT a human readable error message
// src.data<uint8_t>();
// Tensor holds the wrong type, it holds N6paddle8platform7float16E at
// [/paddle/Paddle/paddle/fluid/framework/tensor_impl.h:43]
}
72 changes: 61 additions & 11 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -66,20 +66,35 @@ def get_output():
tensor_to_check_dtype = np.float32
elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
tensor_to_check_dtype = np.float64
elif tensor_to_check_dtype == core.VarDesc.VarType.FP16:
tensor_to_check_dtype = np.float16
# set delta as np.float16; it is automatically promoted to float32/float64 when mixed with wider types
delta = np.array(delta).astype(np.float16)
else:
raise ValueError("Not supported data type " + str(
tensor_to_check_dtype))

gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)

def __get_elem__(tensor, i):
if tensor_to_check_dtype == np.float32:
if tensor_to_check_dtype == np.float16:
numpy_tensor = np.array(tensor).astype(np.float16)
numpy_tensor = numpy_tensor.flatten()
return numpy_tensor[i]
elif tensor_to_check_dtype == np.float32:
return tensor._get_float_element(i)
else:
return tensor._get_double_element(i)

def __set_elem__(tensor, i, e):
if tensor_to_check_dtype == np.float32:
if tensor_to_check_dtype == np.float16:
numpy_tensor = np.array(tensor).astype(np.float16)
shape = numpy_tensor.shape
numpy_tensor = numpy_tensor.flatten()
numpy_tensor[i] = e
numpy_tensor = numpy_tensor.reshape(shape).view(np.uint16)
tensor.set(numpy_tensor, place)
elif tensor_to_check_dtype == np.float32:
tensor._set_float_element(i, e)
else:
tensor._set_double_element(i, e)
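
As an editorial aside: the fp16 branches above go through a flattened numpy copy of the tensor rather than a dedicated element accessor, and the delta comment relies on numpy's automatic type promotion. A minimal pure-numpy sketch of those two mechanics (not part of this commit; a plain float16 array stands in for the LoDTensor):

```python
import numpy as np

x = np.arange(6, dtype=np.float16).reshape(2, 3)
i = 4  # flat index, i.e. element [1, 1]

# "Get": copy to float16, flatten, index -- mirrors the fp16 branch of __get_elem__.
elem = np.array(x).astype(np.float16).flatten()[i]
assert elem == np.float16(4.0)

# "Set": modify the flat copy, restore the shape, then reinterpret the bits as
# uint16 with view() -- no values are converted -- before the buffer would be
# handed back via tensor.set(numpy_tensor, place), as in __set_elem__.
flat = np.array(x).astype(np.float16).flatten()
flat[i] = np.float16(7.5)
raw = flat.reshape(x.shape).view(np.uint16)
assert raw.dtype == np.uint16
assert raw.view(np.float16)[1, 1] == np.float16(7.5)  # same bits, same value

# A float16 delta is promoted automatically when mixed with wider floats,
# which is what the "set delta as np.float16" comment above relies on.
delta = np.array(1e-3).astype(np.float16)
assert (np.zeros(3, dtype=np.float64) + delta).dtype == np.float64
```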
@@ -133,6 +148,11 @@ def try_call_once(self, data_type):
if not self.call_once:
self.call_once = True
self.dtype = data_type
# See the comment of np_dtype_to_fluid_dtype.
# If the input type is uint16, we assume float16 is used
# for the LoDTensor dtype.
if self.dtype == np.uint16:
self.dtype = np.float16

def infer_dtype_from_inputs_outputs(self, inputs, outputs):
def infer_dtype(numpy_dict):
@@ -161,19 +181,25 @@ def feed_var(self, input_vars, place):
for name, np_value in self.inputs[var_name]:
tensor = core.LoDTensor()
if isinstance(np_value, tuple):
tensor.set(np_value[0], place)
tensor.set(
OpTest.np_value_to_fluid_value(np_value[0]), place)
tensor.set_recursive_sequence_lengths(np_value[1])
else:
tensor.set(np_value, place)
tensor.set(
OpTest.np_value_to_fluid_value(np_value), place)
feed_map[name] = tensor
else:
tensor = core.LoDTensor()
if isinstance(self.inputs[var_name], tuple):
tensor.set(self.inputs[var_name][0], place)
tensor.set(
OpTest.np_value_to_fluid_value(self.inputs[var_name][
0]), place)
tensor.set_recursive_sequence_lengths(self.inputs[var_name][
1])
else:
tensor.set(self.inputs[var_name], place)
tensor.set(
OpTest.np_value_to_fluid_value(self.inputs[var_name]),
place)
feed_map[var_name] = tensor

return feed_map
@@ -307,13 +333,22 @@ def find_actual(target_name, fetch_list):
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) +
str(actual_t) + "\n" + str(expect_t))
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t))
if isinstance(expect, tuple):
self.assertListEqual(actual.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
") has different lod at " + str(place))

def _get_places(self):
if self.dtype == np.float16:
if core.is_compiled_with_cuda() and core.op_support_gpu(
self.op_type):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
return [place]
else:
return []
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
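
For context, this is roughly how the new fp16 path is exercised from an operator test: the test sets `self.dtype = np.float16`, wraps its inputs with `np_dtype_to_fluid_dtype`, and `_get_places()` silently skips the checks unless a CUDA device with fp16 support is available. A hedged sketch built on the updated OpTest helpers; the op name, shapes, and tolerance are illustrative assumptions, not taken from this commit:

```python
import numpy as np
from op_test import OpTest


class TestScaleFP16Op(OpTest):
    def setUp(self):
        self.op_type = "scale"          # illustrative op; any fp16-capable op works
        self.dtype = np.float16
        x = np.random.random((10, 10)).astype(self.dtype)
        self.attrs = {'scale': 2.0}
        # Compute the reference output before wrapping: np_dtype_to_fluid_dtype
        # relabels the array's dtype as uint16 in place.
        self.outputs = {'Out': x * 2.0}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}

    def test_check_output(self):
        # On machines without a CUDA device that supports fp16,
        # _get_places() returns [] and the output check is skipped.
        self.check_output(atol=1e-3)
```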
@@ -344,9 +379,9 @@ def __assert_is_close(self, numeric_grads, analytic_grads, names,
def err_msg():
offset = np.argmax(diff_mat > max_relative_error)
return ("%s Variable %s max gradient diff %f over limit %f, "
"the first error element is %d, %f, %f") % (
msg_prefix, name, max_diff, max_relative_error,
offset, a.flatten()[offset], b.flatten()[offset])
"the first error element is %d, expected %f, but got %f"
) % (msg_prefix, name, max_diff, max_relative_error,
offset, a.flatten()[offset], b.flatten()[offset])

self.assertLessEqual(max_diff, max_relative_error, err_msg())

@@ -435,6 +470,21 @@ def np_dtype_to_fluid_dtype(input):
input.dtype = np.uint16
return input

@staticmethod
def fluid_dtype_to_np_dtype(dtype):
"""
See above, convert the dtype to normal type.
"""
if dtype == np.uint16:
dtype = np.float16
return dtype

@staticmethod
def np_value_to_fluid_value(input):
if input.dtype == np.float16:
input = input.view(np.uint16)
return input

def _get_gradient(self,
input_to_check,
place,
@@ -457,7 +507,7 @@ def _get_gradient(self,
if isinstance(place, fluid.CUDAPlace(0)):
use_cuda = True
executor = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name, main_program=program)
use_cuda=use_cuda, loss_name=loss.name, main_program=prog)
else:
executor = Executor(place)
return map(np.array,
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -17,6 +17,8 @@
import math
from op_test import OpTest

np.random.seed(100)


def find_latest_set(num):
return 1 + int(math.floor(math.log(num, 2)))
26 changes: 17 additions & 9 deletions python/paddle/fluid/tests/unittests/testsuite.py
@@ -18,14 +18,6 @@
from paddle.fluid.op import Operator


def as_lodtensor(np_array, lod, place):
tensor = core.LoDTensor()
tensor.set(np_value, place)
if lod is not None:
tensor.set_recursive_sequence_lengths(lod)
return tensor


def create_op(scope, op_type, inputs, outputs, attrs):
kwargs = dict()

@@ -69,14 +61,19 @@ def __create_var__(name, var_name):


def set_input(scope, op, inputs, place):
def np_value_to_fluid_value(input):
if input.dtype == np.float16:
input = input.view(np.uint16)
return input

def __set_input__(var_name, var):
if isinstance(var, tuple) or isinstance(var, np.ndarray):
tensor = scope.find_var(var_name).get_tensor()
if isinstance(var, tuple):
tensor.set_recursive_sequence_lengths(var[1])
var = var[0]
tensor._set_dims(var.shape)
tensor.set(var, place)
tensor.set(np_value_to_fluid_value(var), place)
elif isinstance(var, float):
scope.find_var(var_name).set_float(var)
elif isinstance(var, int):
@@ -104,6 +101,7 @@ def create_var(block, name, np_list, var_proto):
if name not in np_list:
assert var_proto.intermediate, "{} not found".format(name)
else:
# infer the dtype from the numpy value.
np_value = np_list[name]
if isinstance(np_value, tuple):
dtype = np_value[0].dtype
@@ -116,6 +114,16 @@
if is_input:
shape = list(np_value.shape)
lod_level = 0
# NOTE(dzhwinter): type hacking.
# numpy float16 is bound to paddle::platform::float16
# in tensor_py.h with the help of the uint16 datatype, because
# the internal memory representation of float16 in Paddle is
# actually uint16_t. We therefore use np.uint16 in numpy for the
# raw memory so it can pass through pybind. In the test cases we
# feed data as data.view(np.uint16), but the dtype is in fact float16.
# view(np.uint16) does not cast the values; it only reinterprets
# the same bytes as uint16.
if dtype == np.uint16:
dtype = np.float16
return block.create_var(
dtype=dtype, shape=shape, lod_level=lod_level, name=name)
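
To make the NOTE above concrete, here is a hedged sketch of the feed path it describes, using only the `paddle.fluid.core` calls already shown in this diff (`LoDTensor.set`, `CPUPlace`): the float16 values are reinterpreted as uint16 so they can pass through pybind, while `create_var` maps the declared dtype back to float16.

```python
import numpy as np
import paddle.fluid.core as core

place = core.CPUPlace()
value = np.random.random((2, 3)).astype(np.float16)

tensor = core.LoDTensor()
# view(np.uint16) reinterprets the half-precision bits as uint16 so the
# buffer can pass through pybind; no values are converted.
tensor.set(value.view(np.uint16), place)

# On the program side the variable is still created with dtype float16,
# which is exactly the uint16 -> float16 mapping applied in create_var above.
```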

