
Commit

fix regularizer lod bug (PaddlePaddle#17848)
* fix regularizer lod bug; test=develop

* fix exception bug and one_hot expand; test=develop
phlrain authored Jun 10, 2019
1 parent 8062bd5 commit b888a4c
Showing 4 changed files with 16 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/paddle/fluid/clip.py
@@ -21,6 +21,7 @@
from . import layers
from . import framework
from . import core
from .dygraph import not_support

__all__ = [
'ErrorClipByValue',
@@ -335,6 +336,7 @@ def _create_operators(self, param, grad):
return param, new_grad


@not_support
def set_gradient_clip(clip, param_list=None, program=None):
"""
To specify parameters that require gradient clip.
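
For context, a minimal sketch of what a not_support-style guard could look like; the real decorator lives in paddle.fluid.dygraph and is not shown in this diff, so the body below is an assumption, not PaddlePaddle's implementation.

import functools

from paddle.fluid.framework import in_dygraph_mode


def not_support(func):
    # Hypothetical sketch: reject static-graph-only APIs such as
    # set_gradient_clip when they are called in dygraph mode.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if in_dygraph_mode():
            raise RuntimeError(
                "'%s' is not supported in dygraph mode." % func.__name__)
        return func(*args, **kwargs)

    return wrapper
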
2 changes: 2 additions & 0 deletions python/paddle/fluid/framework.py
@@ -654,6 +654,8 @@ def dtype(self):
@property
def lod_level(self):
# TODO(minqiyang): Support lod_level in dygraph mode
if in_dygraph_mode():
raise Exception("Dygraph mode does not support lod")
return self.desc.lod_level()

@property
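
A short usage sketch of the new behavior under dygraph, assuming the 1.x fluid API of this era; the tensor and its values are illustrative only.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 3], dtype='float32'))
    try:
        x.lod_level  # now raises: LoD is not supported in dygraph mode
    except Exception as e:
        print(e)
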
2 changes: 2 additions & 0 deletions python/paddle/fluid/layers/nn.py
@@ -6576,6 +6576,7 @@ def one_hot(input, depth):
inputs = {'X': input}
attrs = {'depth': depth}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {}
helper.append_op(
@@ -9383,6 +9384,7 @@ def contain_tensor(expand_times):
new_expand_times = []
for ele in expand_times:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times.append(ele)
else:
assert (isinstance(ele, int))
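
Both hunks handle the case where a shape-like argument (one_hot's depth, an element of expand's expand_times) is itself a tensor; marking it with stop_gradient = True keeps gradients from flowing back into it. A hedged static-graph sketch of that calling pattern, assuming the 1.x fluid layers API; names and values are illustrative only.

import paddle.fluid as fluid

label = fluid.layers.data(name='label', shape=[1], dtype='int64')
depth = fluid.layers.fill_constant(shape=[1], dtype='int32', value=10)
# depth is a tensor here, so one_hot now marks it with stop_gradient = True.
onehot = fluid.layers.one_hot(input=label, depth=depth)

x = fluid.layers.data(name='x', shape=[2, 3], dtype='float32')
times = fluid.layers.fill_constant(shape=[1], dtype='int32', value=2)
# Tensor elements of expand_times are likewise marked stop_gradient = True.
expanded = fluid.layers.expand(x, expand_times=[times, 1, 1])
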
14 changes: 10 additions & 4 deletions python/paddle/fluid/regularizer.py
@@ -162,8 +162,11 @@ def __call__(self, param, grad, block):
assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block)

decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if framework.in_dygraph_mode():
decay = block.create_var(dtype=param.dtype, shape=param.shape)
else:
decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

# Append Op to calculate decay
block.append_op(
@@ -231,8 +234,11 @@ def __call__(self, param, grad, block):
assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block)

decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if framework.in_dygraph_mode():
decay = block.create_var(dtype=param.dtype, shape=param.shape)
else:
decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

# Append sign op
block.append_op(
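
With the branch above, the regularizer stops querying param.lod_level under dygraph, where that property now raises. A hedged end-to-end sketch of how this surfaces to users, assuming the 1.x fluid dygraph API of this era; the FC, SGD, and L2Decay names and the exact optimizer wiring are assumptions, not part of this diff.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    fc = fluid.dygraph.FC('fc', size=4)
    sgd = fluid.optimizer.SGD(
        learning_rate=0.01,
        regularization=fluid.regularizer.L2Decay(regularization_coeff=1e-4))
    x = fluid.dygraph.to_variable(np.random.rand(2, 4).astype('float32'))
    loss = fluid.layers.reduce_mean(fc(x))
    loss.backward()
    # The weight-decay term is now appended without touching lod_level.
    sgd.minimize(loss)
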
