Imperative transfer gru unit #16095

Merged 6 commits on Mar 12, 2019
137 changes: 136 additions & 1 deletion python/paddle/fluid/imperative/nn.py
@@ -22,7 +22,8 @@
from ..framework import Variable, OpProtoHolder
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant
__all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding']

__all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding', 'GRUUnit']


class Conv2D(layers.Layer):
@@ -468,3 +469,137 @@ def forward(self, input):
})

return out


class GRUUnit(layers.Layer):
"""
**GRU unit layer**

If origin_mode is True, the equations of a GRU step come from the paper
`Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_:

.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)

r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)

m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)

h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)

If origin_mode is False, the equations of a GRU step come from the paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_:

.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)

r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)

m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)

h_t & = dot((1-u_t), h_{t-1}) + dot(u_t, m_t)


The inputs of the GRU unit include :math:`z_t` and :math:`h_{t-1}`. In terms
of the equations above, :math:`z_t` is split into three parts:
:math:`xu_t`, :math:`xr_t` and :math:`xm_t`. This means that, to
implement a full GRU unit operator for an input, a fully
connected layer has to be applied first, such that :math:`z_t = W_{fc}x_t`.

The terms :math:`u_t` and :math:`r_t` represent the update and reset gates
of the GRU cell. Unlike the LSTM, the GRU has one fewer gate. However, there is
an intermediate candidate hidden output, which is denoted by :math:`m_t`.
This layer has three outputs: :math:`h_t`, :math:`dot(r_t, h_{t-1})`,
and the concatenation of :math:`u_t`, :math:`r_t` and :math:`m_t`.

Args:
input (Variable): The FC-transformed input value of the current step.
[Review thread on this line]
Contributor: add name_scope?
velconia (Collaborator, author): OK, I'll add it in next PR
velconia, Mar 12, 2019: Done
name_scope (str): See base class.
hidden (Variable): The hidden value of the GRU unit from the previous step.
size (integer): The input dimension value.
param_attr (ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weight matrix. Note:

- The shape of the weight matrix is :math:`(D \\times 3D)`, where
:math:`D` is the hidden size.
- All elements in the weight matrix can be divided into two parts.
The first part contains the weights of the update gate and reset gate,
with shape :math:`(D \\times 2D)`, and the second part contains the
weights of the candidate hidden state, with shape :math:`(D \\times D)`.

If it is set to None or one attribute of ParamAttr, gru_unit will
create a ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of the GRU. Note that the bias with shape :math:`(1 \\times 3D)` concatenates
the biases of the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, gru_unit will create a ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized to zero. Default: None.
activation (string): The activation type for cell (actNode).
Default: 'tanh'
gate_activation (string): The activation type for gates (actGate).
Default: 'sigmoid'

Returns:
tuple: The hidden value, reset-hidden value and gate values.
"""

def __init__(self,
name_scope,
size,
param_attr=None,
bias_attr=None,
activation='tanh',
gate_activation='sigmoid',
origin_mode=False,
dtype='float32'):
super(GRUUnit, self).__init__(name_scope)

activation_dict = dict(
identity=0,
sigmoid=1,
tanh=2,
relu=3, )
self._activation = activation_dict[activation]
self._gate_activation = activation_dict[gate_activation]

self._dtype = dtype
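# `size` is passed in as 3 * D to match layers.gru_unit; recover the hidden size D.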
size = size // 3
# create weight
self._weight = self.create_parameter(
[Review thread on this line]
Contributor: I'm thinking, maybe we should make the parameter public so people can use it.
velconia (Collaborator, author): Can't agree more~
attr=param_attr, shape=[size, 3 * size], dtype=dtype)

# create bias
bias_size = [1, 3 * size]
self._bias = self.create_parameter(
attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=True)

def forward(self, input, hidden):
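# input: FC-transformed step input of shape (N, 3 * D); hidden: previous hidden state of shape (N, D).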
inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self._weight}
if self._bias:
inputs['Bias'] = self._bias

gate = self._helper.create_variable_for_type_inference(self._dtype)
reset_hidden_pre = self._helper.create_variable_for_type_inference(
self._dtype)
updated_hidden = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type='gru_unit',
inputs=inputs,
outputs={
'Gate': gate,
'ResetHiddenPrev': reset_hidden_pre,
'Hidden': updated_hidden,
},
attrs={
'activation': self._activation,
'gate_activation': self._gate_activation,
})

return updated_hidden, reset_hidden_pre, gate
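
To make the documented equations concrete, here is a minimal NumPy sketch of a single GRU step with origin_mode=False, mirroring the docstring above. The helper name numpy_gru_step and the explicit slicing are illustrative only and are not part of the Paddle API.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def numpy_gru_step(z, h_prev, W, b):
    # z      : (N, 3D) FC-transformed input, i.e. z = x @ W_fc
    # h_prev : (N, D)  hidden state from the previous step
    # W      : (D, 3D) first (D, 2D) block for the update/reset gates, last (D, D) block for the candidate
    # b      : (1, 3D) concatenated gate and candidate biases
    D = h_prev.shape[1]
    xu, xr, xm = z[:, :D], z[:, D:2 * D], z[:, 2 * D:]
    b_u, b_r, b_m = b[:, :D], b[:, D:2 * D], b[:, 2 * D:]
    hW = h_prev @ W[:, :2 * D]                            # hidden-to-gate projection
    u = sigmoid(xu + hW[:, :D] + b_u)                     # update gate
    r = sigmoid(xr + hW[:, D:] + b_r)                     # reset gate
    m = np.tanh(xm + (r * h_prev) @ W[:, 2 * D:] + b_m)   # candidate hidden state
    h = (1.0 - u) * h_prev + u * m                        # origin_mode=False combination
    # The three return values correspond to Hidden, ResetHiddenPrev and Gate of the gru_unit op.
    return h, r * h_prev, np.concatenate([u, r, m], axis=1)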
116 changes: 112 additions & 4 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -22,6 +22,7 @@
import time
import itertools
import collections
from collections import defaultdict

import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -257,8 +258,65 @@ def calc_output(self, place):
outs, _ = self._calc_output(place)
return outs

def _calc_output(self, place, parallel=False, no_check_set=None):
def _create_var_from_numpy(self, value):
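# A tuple input is interpreted as (ndarray, recursive_sequence_lengths); the LoD is attached to the created imperative variable's tensor.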
if isinstance(value, tuple):
data = value[0]
lod = value[1]
v = fluid.imperative.base.to_variable(value=data)
v._ivar.value().get_tensor().set_recursive_sequence_lengths(lod)
return v
else:
return fluid.imperative.base.to_variable(value)

def _calc_imperative_output(self, place, parallel=False, no_check_set=None):
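# Construct the op inside the imperative (dygraph) guard and return its output variables keyed by output name.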
with fluid.imperative.base.guard(place=place):
block = fluid.default_main_program().global_block()

# prepare input variable
inputs = defaultdict(list)
for name, np_value in six.iteritems(self.inputs):
if not isinstance(np_value, list):
np_value = [np_value]

for i in range(len(np_value)):
inputs[name].append(
self._create_var_from_numpy(np_value[i]))

# prepare output variable
outputs = defaultdict(list)
for name, np_value in six.iteritems(self.outputs):
if not isinstance(np_value, list):
np_value = [np_value]

for i in range(len(np_value)):
value = np_value[i]
if isinstance(value, tuple):
v = block.create_var(
name="%s_out%d" % (name, i),
dtype=value[0].dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
v._ivar.value().get_tensor(
).set_recursive_sequence_lengths(value[1])
else:
v = block.create_var(
name="%s_out%d" % (name, i),
dtype=value.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
outputs[name].append(v)

block.append_op(
type=self.op_type,
inputs=inputs,
outputs=outputs,
attrs=self.attrs)

return outputs

def _calc_output(self, place, parallel=False, no_check_set=None):
program = Program()
block = program.global_block()
self._append_ops(block)
@@ -305,8 +363,13 @@ def check_output_with_place(self,
place,
atol,
no_check_set=None,
equal_nan=False):
equal_nan=False,
check_imperative=False):
if check_imperative:
imperative_outs = self._calc_imperative_output(
place, no_check_set=no_check_set)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)

for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
@@ -330,6 +393,10 @@ def find_actual(target_name, fetch_list):
type(sub_out))
for item in sub_out:
sub_out_name, expect = item[0], item[1]
if check_imperative:
imperative_actual = imperative_outs[sub_out_name][0]
imperative_actual_t = np.array(
imperative_actual._ivar.value().get_tensor())
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
@@ -340,12 +407,31 @@ def find_actual(target_name, fetch_list):
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place))
if check_imperative:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place) + " in imperative mode")
if isinstance(expect, tuple):
self.assertListEqual(
actual.recursive_sequence_lengths(), expect[1],
"Output (" + sub_out_name +
") has different lod at " + str(place))
if check_imperative:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in imperative mode")
else:
if check_imperative:
imperative_actual = imperative_outs[out_name][0]
imperative_actual_t = np.array(
imperative_actual._ivar.value().get_tensor())
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
@@ -357,10 +443,27 @@ def find_actual(target_name, fetch_list):
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t) + " in class " + self.__class__.__name__)
if check_imperative:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(imperative_actual_t) + " in class " +
self.__class__.__name__)
if isinstance(expect, tuple):
self.assertListEqual(actual.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
") has different lod at " + str(place))
if check_imperative:
self.assertListEqual(
imperative_actual._ivar.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in imperative mode")

def _get_places(self):
if self.dtype == np.float16:
@@ -383,10 +486,15 @@ def _get_places(self):
places.append(core.CUDAPlace(0))
return places

def check_output(self, atol=1e-5, no_check_set=None, equal_nan=False):
def check_output(self,
atol=1e-5,
no_check_set=None,
equal_nan=False,
check_imperative=False):
places = self._get_places()
for place in places:
self.check_output_with_place(place, atol, no_check_set, equal_nan)
self.check_output_with_place(place, atol, no_check_set, equal_nan,
check_imperative)

def check_output_customized(self, checker):
places = self._get_places()
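For reference, a minimal sketch of how an operator test might opt into the new imperative check; the op name some_op, its inputs and its expected outputs below are placeholders for illustration, not a real Paddle operator.

import numpy as np
from op_test import OpTest

class TestSomeOp(OpTest):          # hypothetical test class
    def setUp(self):
        self.op_type = "some_op"   # placeholder op name
        x = np.random.rand(4, 8).astype("float32")
        self.inputs = {'X': x}
        self.outputs = {'Out': x}  # expected output for the placeholder op
        self.attrs = {}

    def test_check_output(self):
        # check_imperative=True additionally builds the op in imperative mode
        # and compares its outputs against self.outputs.
        self.check_output(atol=1e-5, check_imperative=True)
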
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_gru_op.py
@@ -156,7 +156,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output(atol=1e-8)
self.check_output(atol=1e-8, check_imperative=True)

def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
41 changes: 41 additions & 0 deletions python/paddle/fluid/tests/unittests/test_layers.py
@@ -112,6 +112,47 @@ def test_conv2d(self):
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
self.assertTrue(np.allclose(static_ret, static_ret2))

def test_gru_unit(self):
lod = [[2, 4, 3]]
D = 5
T = sum(lod[0])
N = len(lod[0])

input = np.random.rand(T, 3 * D).astype('float32')
hidden_input = np.random.rand(T, D).astype('float32')

with self.static_graph():
x = layers.data(name='x', shape=[-1, D * 3], dtype='float32')
hidden = layers.data(name='hidden', shape=[-1, D], dtype='float32')
updated_hidden, reset_hidden_pre, gate = layers.gru_unit(
input=x, hidden=hidden, size=D * 3)
static_ret = self.get_static_graph_result(
feed={'x': input,
'hidden': hidden_input},
fetch_list=[updated_hidden, reset_hidden_pre, gate])

with self.static_graph():
x = layers.data(name='x', shape=[-1, D * 3], dtype='float32')
hidden = layers.data(name='hidden', shape=[-1, D], dtype='float32')
updated_hidden, reset_hidden_pre, gate = layers.gru_unit(
input=x, hidden=hidden, size=D * 3)
gru = nn.GRUUnit('gru', size=D * 3)
updated_hidden, reset_hidden_pre, gate = gru(x, hidden)

static_ret2 = self.get_static_graph_result(
feed={'x': input,
'hidden': hidden_input},
fetch_list=[updated_hidden, reset_hidden_pre, gate])

with self.dynamic_graph():
gru = nn.GRUUnit('gru', size=D * 3)
dy_ret = gru(
base.to_variable(input), base.to_variable(hidden_input))

for i in range(len(static_ret)):
self.assertTrue(np.allclose(static_ret[i], static_ret2[i]))
self.assertTrue(np.allclose(static_ret[i], dy_ret[i]._numpy()))


class TestBook(unittest.TestCase):
def test_fit_a_line(self):
Expand Down