Cmrnorm #854

Merged: 17 commits merged into PaddlePaddle:develop on Dec 20, 2016
Conversation

@hedaoyuan (Contributor) commented Dec 13, 2016

What this PR implements

  1. Fix "CPU version of crossMapNormalBwd is not implemented" (#294).
  2. Add a Paddle Function.
    The two parts are written together so that normal serves as an example to guide the upcoming refactoring of the Matrix API. From now on, adding an algorithm to Paddle will no longer mean adding an algorithm API to Matrix; instead, Paddle gains an algorithm Function.

See the description in #892 for how a Paddle Function is used. In this PR, a cross map normalize Function is added as follows (a usage sketch follows this list):

  1. Define CrossMapNormal/CrossMapNormalGrad, which implement the algorithm API;
  2. Wrap the API as CrossMapNormalFunc/CrossMapNormalGradFunc in the form of FunctionBase;
  3. Create the forward/backward functions in CMRProjectionNormLayer::init;
  4. Call the corresponding functions in CMRProjectionNormLayer::forward/backward;
  5. Add a TEST in cross_map_normal_op_test.cpp.
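The sketch below pieces steps 3 and 4 together, based only on the code quoted later in this conversation (the FuncConfig setters and the forward_->calc call with Tensor arguments). The creation helper, the config keys ("size", "scale", "pow"), and the members size_, scale_, pow_ are illustrative assumptions, not necessarily the exact code in this PR.

// Sketch only: how CMRProjectionNormLayer might create and invoke the
// cross map normalize Function added in this PR.
bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
                                  const ParameterMap& parameterMap) {
  // ... existing init checks ...
  // Pick the CPU or GPU registration of CrossMapNormal based on useGpu_;
  // createCrossMapNormalFunction is a hypothetical helper, the PR itself
  // branches on useGpu_ inside init.
  forward_ = createCrossMapNormalFunction(useGpu_);
  forward_->init(FuncConfig()
                     .set("size", (size_t)size_)   // config keys are assumed
                     .set("scale", (real)scale_)
                     .set("pow", (real)pow_));
  return true;
}

void CMRProjectionNormLayer::forward(PassType passType) {
  MatrixPtr input = getInputValue(0);
  MatrixPtr outV = getOutputValue();
  // ... resize outV and denoms_, set dims_ to {batch, channels, H, W} ...
  forward_->calc(
      {Tensor(input->getData(), dims_)},                                     // inputs
      {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},   // outputs
      {});                                                                   // in/out args
}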


const int start = -((int)sizeX) / 2;
const int end = (int)sizeX + start;
const real ratio = -(real)2 * scale * pow;
Contributor:

can we write: -2.0 * scale * pow ?


const int start = -((int)sizeX) / 2;
const int end = (int)sizeX + start;
const real ratio = -(real)2 * scale * pow;
Contributor:

can we write: -2.0 * scale * pow ?

@wangkuiyi (Collaborator), Dec 13, 2016:

I think we cannot since 2.0 is considered a double-typed value by the compiler, but real is a macro which could evaluate to float.

Contributor:

but const real ratio might convert double into float if real is float.

Contributor (Author):

but const real ratio might convert double into float if real is float. That may produce a warning.
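A small self-contained illustration of the point under discussion, assuming a build where real is a float typedef:

#include <cstdio>

typedef float real;  // assumption: a float build, so real is float

int main() {
  const real scale = 0.001f;
  const real power = 0.75f;

  // -2.0 is a double literal, so the arithmetic is done in double and then
  // narrowed to float on assignment; compilers may warn about the implicit
  // double -> float conversion.
  const real ratio_from_double = -2.0 * scale * power;

  // Casting the literal keeps the whole expression in real (float here),
  // which is what -(real)2 * scale * pow in the PR does.
  const real ratio_from_real = -(real)2 * scale * power;

  std::printf("%g %g\n", ratio_from_double, ratio_from_real);
  return 0;
}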

}

template <>
void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
Contributor:

Can you write down the math expression here, to make the code easier to understand?

Contributor (Author):

The backward formula is derived from the forward formula.
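For reference, a sketch of the usual cross-map (local response) normalization formulas and the gradient they imply; this is the textbook derivation, not text copied from the PR. Writing scale as \alpha, pow as \beta, N(c) for the sizeX channels centred on channel c, and assuming the additive constant in the denominator is 1:

d_c = 1 + \alpha \sum_{c' \in N(c)} x_{c'}^2, \qquad y_c = x_c \, d_c^{-\beta}

\frac{\partial L}{\partial x_c}
  = \frac{\partial L}{\partial y_c} \, d_c^{-\beta}
    \;-\; 2\alpha\beta \, x_c \sum_{c' :\, c \in N(c')}
          \frac{\partial L}{\partial y_{c'}} \, \frac{y_{c'}}{d_{c'}}

The factor -2\alpha\beta in the second term corresponds to the ratio = -(real)2 * scale * pow constant that appears earlier in this diff.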


// NCHW
template <>
void CrossMapNormal<DEVICE_TYPE_CPU>::operator()(CpuMatrix& outputs,
Contributor:

Can you write down the math expression here, to make the code easier to understand?

Contributor (Author):

Done.

@tianbingsz (Contributor) left a review:

Looks very good, I only have minor comments.

@@ -1262,6 +1263,150 @@ TEST(Matrix, MaxOutFwdBwd) {
}
}

void testCrossMapNormalFwd(
Contributor (Author):

The test is written here for now; it will be moved into the function module later.


REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc);
#ifndef PADDLE_ONLY_CPU
@hedaoyuan (Contributor, Author), Dec 15, 2016:

Later, the PADDLE_ONLY_CPU check will be moved inside REGISTER_TYPED_FUNC.

Contributor (Author):

Some Function implementations may not support the GPU, so this is left as is for now.

@@ -100,6 +101,11 @@ class Layer {
/// Mark input grad in(true) or out(false) of backward function.
std::vector<bool> markInBackward_;

/// Layer forward function
FunctionBase* forward_;
Contributor (Author):

This will later be changed to a std::vector, so that one Layer can contain multiple Functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.
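A minimal sketch of the change described here; the member names are illustrative assumptions, not the code that was eventually merged:

#include <memory>
#include <vector>

class FunctionBase;  // Paddle's Function interface, declared elsewhere

class Layer {
protected:
  // Instead of a single FunctionBase* per direction, keep one list each, so
  // that one Layer can own and run several Functions in order.
  std::vector<std::shared_ptr<FunctionBase>> forwards_;
  std::vector<std::shared_ptr<FunctionBase>> backwards_;
};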

@@ -45,6 +45,23 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
/* the size of inputs for norm-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1);

if (useGpu_) {
Contributor (Author):

A macro will be written later to simplify this code.

Contributor (Author):

Done.

@tianbingsz (Contributor):
Very exciting, will follow the work.

forward_->calc(
{Tensor(input->getData(), dims_)},
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
{});
Contributor:

Will this later be changed so that the data is constructed as a Tensor from the start, instead of being converted in between?

Contributor (Author):

The data structure used as a Function argument and the data structure used for computation are two different things, with one conversion in between.
The argument data structure (which can represent multi-dimensional data of multiple types) is intended to replace the current paddle::Argument, while the computation data structure inherits from paddle::TensorExpression. Two separate issues will be written later to explain these two points.
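A toy sketch of the two-level design described here, based only on the Tensor(data, dims_) calls quoted above; the field names and the conversion shown in the comment are assumptions, not the actual PR types:

#include <cstddef>
#include <vector>

typedef float real;

// Argument-level structure: a typed buffer plus its dimensions. This plays
// the role the author describes for the replacement of paddle::Argument.
struct Tensor {
  real* data;
  std::vector<size_t> dims;  // e.g. {batch, channels, height, width}
};

// Inside a Function's calc(), each Tensor argument is converted once into the
// structure actually used for computation (in this PR a CpuMatrix/GpuMatrix,
// which builds on paddle::TensorExpression), roughly:
//   CpuMatrix input(inputs[0].data,
//                   /*height=*/dims[0],
//                   /*width=*/dims[1] * dims[2] * dims[3]);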

@tianbingsz (Contributor) left a review:

Looks very good to me, only minor comments. You may ignore some of them to keep development moving quickly.


template <>
FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
Contributor:

Should we use count > 0 to indicate there is already a value for the key?
CHECK(valueMap_.count(key) > 0) << "Duplicated value: " << key;

Contributor (Author):

CHECK(true) will pass. What we need here is for valueMap_.count(key) == 0 to be true.

Collaborator:

This should be CHECK_EQ(0, valueMap_.count(key)), so that if count returns something other than 0, glog can print the returned value, which makes debugging easier.
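A minimal glog example of the difference; the failure messages in the comments are approximate, check your glog version for the exact format:

#include <glog/logging.h>

#include <map>
#include <string>

int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  std::map<std::string, double> valueMap_{{"scale", 0.001}};

  // Both checks pass for a key that is not present yet.
  CHECK(valueMap_.count("pow") == 0) << "Duplicated value: pow";
  CHECK_EQ(0, valueMap_.count("pow")) << "Duplicated value: pow";

  // For a duplicated key, CHECK only reports the stringified condition, e.g.
  //   Check failed: valueMap_.count(key) == 0 Duplicated value: scale
  // while CHECK_EQ also prints both operand values, e.g.
  //   Check failed: 0 == valueMap_.count(key) (0 vs. 1) Duplicated value: scale
  return 0;
}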

Collaborator:

CHECK ==> CHECK_EQ


template <>
FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
Contributor:

same as line 35

Contributor (Author):

same as above

Collaborator:

CHECK_EQ

Collaborator:

CHECK ==> CHECK_EQ

template <DeviceType Device>
void CrossMapNormal(real* outputs,
real* denoms,
real* inputs,
Contributor:

Should we use const for inputs, such as const real * denoms? Anyway, if it is too complicated to do, please ignore this comment.

Contributor (Author):

Done.

*/
template <DeviceType Device>
void CrossMapNormalGrad(real* inputsGrad,
real* inputsValue,
Contributor:

same as line 40

Contributor (Author):

Done.

const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < imageSize) {
const int w = idx % width;
const int h = (idx / width) % height;
Contributor:

Should we check width > 0 and height > 0 here? Seems not so important here though.

Contributor (Author):

size_t is unsigned

Contributor:

unsigned could be 0.

template <>
void CrossMapNormal<DEVICE_TYPE_GPU>(real* outputs,
real* denoms,
real* inputs,
Contributor:

Should we write it as const real* inputs?

Contributor (Author):

Done.


template <>
void CrossMapNormalGrad<DEVICE_TYPE_GPU>(real* inputsGrad,
real* inputsValue,
Contributor:

same as line 66.

Contributor (Author):

Done.

@tianbingsz merged commit 42e1217 into PaddlePaddle:develop on Dec 20, 2016
@wangkuiyi (Collaborator) left a review:

Adding some comments after the fact; they make the same points as #973.

template <>
size_t FuncConfig::get<size_t>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
Collaborator:

CHECK ==> CHECK_NE

Contributor (Author):

CHECK_NE does not support the type of it: glog has to print both operands on failure, and map iterators have no operator<<.
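A minimal example of why, assuming glog; the commented-out line is the one that fails to build:

#include <glog/logging.h>

#include <map>
#include <string>

int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  std::map<std::string, double> valueMap_{{"scale", 0.001}};

  auto it = valueMap_.find("scale");

  // Fine: CHECK only needs the condition to be convertible to bool.
  CHECK(it != valueMap_.end()) << "Cannot find value: 'scale'";

  // Does not compile: CHECK_NE needs operator<< for both operands so it can
  // print them on failure, and std::map iterators are not streamable.
  // CHECK_NE(it, valueMap_.end()) << "Cannot find value: 'scale'";
  return 0;
}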


template <>
FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
Collaborator:

CHECK ==> CHECK_EQ


template <>
FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
Collaborator:

CHECK ==> CHECK_EQ

template <>
real FuncConfig::get<real>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
Collaborator:

CHECK ==> CHECK_NE

Contributor (Author):

CHECK_NE does not support the type of it, for the same reason as above.


} // namespace paddle

using paddle::FunctionCompare;
Collaborator:

Header files should not contain using directives, especially not using in the global namespace. See #973 for details.

Contributor (Author):

Done, see PR #1006.


virtual void init(const FuncConfig& config) {}

virtual void calc(const Arguments& inputs,
Collaborator:

This interface function needs a detailed comment.
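A sketch of the kind of comment being asked for. The wording, the @param descriptions, and the name of the third argument list (inouts, inferred from the empty {} in the calc call quoted earlier) are assumptions, not the documentation that was eventually written.

/**
 * Run the computation of this Function.
 *
 * @param[in]  inputs   read-only input tensors (data pointer plus dims).
 * @param[out] outputs  pre-allocated tensors the results are written into.
 * @param      inouts   tensors read and updated in place (empty for
 *                      CrossMapNormal, as in the call shown earlier).
 *
 * init(config) must have been called before calc() so that the Function's
 * parameters (e.g. size/scale/pow for cross map normalization) are set.
 */
virtual void calc(const Arguments& inputs,
                  const Arguments& outputs,
                  const Arguments& inouts) {}  // parameter names assumed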
