Skip to content

Commit

Permalink
adjust the dims range to [1,6] and fix some problem
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Jul 3, 2018
1 parent 9ca88fa commit 0cef33a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/operators/squeeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ class SqueezeOp : public framework::OperatorWithKernel {
"Output(Out) of SqueezeOp should not be null.");

const auto& x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<9).
PADDLE_ENFORCE(x_dims.size() <= 9,
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, dynamic dimensions must have "
"between [1, 9] dimensions.");
"between [1, 6] dimensions (Eigen limit).");

const auto& axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/operators/squeeze_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/fluid/operators/squeeze_op.h"

namespace ops = paddle::operators;
Expand Down
26 changes: 13 additions & 13 deletions python/paddle/fluid/tests/unittests/test_squeeze_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self):

self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

def test_check_output(self):
Expand All @@ -46,7 +46,7 @@ def setUp(self):

self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

def test_check_output(self):
Expand All @@ -65,7 +65,7 @@ def setUp(self):

self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

def test_check_output(self):
Expand All @@ -78,13 +78,13 @@ def test_check_grad(self):
# Correct: Just part of axes be squeezed.
class TestSqueezeOp4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
ori_shape = (3, 1, 5, 1, 4, 1)
axes = (1, -1)
new_shape = (3, 5, 1, 4)

self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": False}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

def test_check_output(self):
Expand Down Expand Up @@ -122,7 +122,7 @@ def setUp(self):

self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

def test_check_output(self):
Expand All @@ -141,7 +141,7 @@ def setUp(self):

self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

def test_check_output(self):
Expand All @@ -154,13 +154,13 @@ def test_check_grad(self):
# Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5, 1, 4, 1)
axes = (2, 6)
new_shape = (1, 3, 5, 1, 4)
ori_shape = (3, 1, 5, 1, 4, 1)
axes = (1, -1)
new_shape = (3, 5, 1, 4)

self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inpalce": True}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

def test_check_output(self):
Expand Down

0 comments on commit 0cef33a

Please sign in to comment.