
Commit

Revert "[numpy] add torch.concatenate, alias of torch.cat (pytorch#…
pytorchmergebot committed Sep 14, 2022
1 parent 23b7a5f commit fa7bf3e
Showing 9 changed files with 82 additions and 117 deletions.
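For context before the per-file diff: the reverted commit had exposed `torch.concatenate` as a NumPy-style alias that simply forwarded to `torch.cat`; after this revert only `torch.cat` (and the pre-existing `torch.concat` alias) remain. A minimal sketch of the surviving API, assuming two small float tensors purely for illustration:

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

# torch.cat stays the canonical joining op after this revert.
joined = torch.cat([a, b], dim=0)
assert joined.shape == (4, 3)

# The reverted change had additionally made
#   torch.concatenate([a, b], dim=0)
# behave as an alias of torch.cat; that alias is what this commit removes.
```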
17 changes: 0 additions & 17 deletions aten/src/ATen/native/TensorShape.cpp
@@ -429,23 +429,6 @@ Tensor concat(TensorList tensors, int64_t dim) {
return at::cat(tensors, dim);
}

// torch.concatenate, alias for torch.cat
Tensor& concatenate_out(TensorList tensors, Dimname dim, Tensor& result) {
return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
}

Tensor concatenate(TensorList tensors, Dimname dim) {
return at::cat(tensors, dimname_to_position(tensors[0], dim));
}

Tensor& concatenate_out(TensorList tensors, int64_t dim, Tensor & result) {
return at::cat_out(result, tensors, dim);
}

Tensor concatenate(TensorList tensors, int64_t dim) {
return at::cat(tensors, dim);
}

static bool sizes_match_except(IntArrayRef s1, IntArrayRef s2, int64_t dim_except /* should already be wrapped */) {
if (s1.size() != s2.size()) {
return false;
9 changes: 0 additions & 9 deletions aten/src/ATen/native/native_functions.yaml
@@ -1209,15 +1209,6 @@

- func: concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)

# torch.concatenate, alias for torch.cat
- func: concatenate(Tensor[] tensors, int dim=0) -> Tensor

- func: concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)

- func: concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor

- func: concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)

- func: block_diag(Tensor[] tensors) -> Tensor
variants: function
dispatch:
1 change: 0 additions & 1 deletion docs/source/torch.rst
@@ -86,7 +86,6 @@ Indexing, Slicing, Joining, Mutating Ops
argwhere
cat
concat
concatenate
conj
chunk
dsplit
5 changes: 1 addition & 4 deletions functorch/test/test_vmap.py
@@ -3277,9 +3277,7 @@ def test():
{torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'),
))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
xfail('cat'),
}))
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail)
def test_vmap_exhaustive(self, device, dtype, op):
# needs to be fixed
inplace_failure_list = (
@@ -3295,7 +3293,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('cat'),
xfail('complex'),
xfail('copysign'),
xfail('histogram'),
80 changes: 80 additions & 0 deletions test/test_tensor_creation_ops.py
@@ -491,6 +491,10 @@ def test_cat_empty_legacy(self, device):
res1 = torch.cat([empty, empty], dim=1)
self.assertEqual(res1, empty)

with self.assertRaisesRegex(RuntimeError,
'non-empty list of Tensors'):
torch.cat([], dim=1)

def test_cat_empty(self, device):
dtype = torch.float32

@@ -504,10 +508,39 @@ def test_cat_empty(self, device):
res1 = torch.cat([empty, empty], dim=1)
self.assertEqual(res1, empty)

# check non-legacy-behavior (sizes don't match)
empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device)
self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1))
self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1))

# check non-legacy-behavior (dimensions don't match)
empty = torch.randn((4, 0), dtype=dtype, device=device)
self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1))
self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1))

def test_cat_out(self, device):
x = torch.zeros((0), device=device)
y = torch.randn((4, 6), device=device)

with self.assertRaisesRegex(
RuntimeError,
r"unsupported operation: some elements of the input tensor and "
r"the written-to tensor refer to a single memory location."):
torch.cat([x, y], dim=0, out=x)

with self.assertRaisesRegex(
RuntimeError,
r"unsupported operation: some elements of the input tensor and "
r"the written-to tensor refer to a single memory location."):
torch.cat([x, y], dim=0, out=y)

z = torch.zeros((4, 6), device=device)
with self.assertRaisesRegex(
RuntimeError,
r"unsupported operation: some elements of the input tensor and "
r"the written-to tensor refer to a single memory location."):
torch.cat([y, z], out=z[:2, :])

w = y.view(-1).clone()
a = torch.cat([w[:2], w[4:6]])
b = torch.cat([w[:2], w[4:6]], out=w[6:10])
@@ -628,11 +661,32 @@ def test_cat_out_memory_format(self, device):

self.assertTrue(res3_cuda.is_contiguous(memory_format=torch.channels_last))

@onlyCUDA
@deviceCountAtLeast(2)
def test_cat_different_devices(self, devices):
cuda0 = torch.randn((3, 3), device=devices[0])
cuda1 = torch.randn((3, 3), device=devices[1])
with self.assertRaisesRegex(RuntimeError,
"Expected all tensors to be on the same device"):
torch.cat((cuda0, cuda1))

with self.assertRaisesRegex(RuntimeError,
"Expected all tensors to be on the same device"):
torch.cat((cuda0, cuda0), out=cuda1)

@onlyCUDA
def test_cat_stack_cross_devices(self, device):
cuda = torch.randn((3, 3), device=device)
cpu = torch.randn((3, 3), device='cpu')

# cat
with self.assertRaisesRegex(RuntimeError,
"Expected all tensors to be on the same device"):
torch.cat((cuda, cpu))
with self.assertRaisesRegex(RuntimeError,
"Expected all tensors to be on the same device"):
torch.cat((cpu, cuda))

# Stack
with self.assertRaisesRegex(RuntimeError,
"Expected all tensors to be on the same device"):
@@ -1005,6 +1059,18 @@ def test_cat_big(self, device):
result = torch.cat(concat_list)
self.assertEqual(result.size(0), SIZE1 + SIZE2)

@onlyCPU
def test_cat_bad_input_sizes(self, device):
x = torch.randn(2, 1, device=device)
y = torch.randn(2, 1, 1, device=device)
z = torch.randn(2, 1, 1, device=device)
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))

x = torch.randn(2, 1, 2, device=device)
y = torch.randn(2, 1, 1, device=device)
z = torch.randn(2, 2, 1, device=device)
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))

@onlyCPU
@dtypes(torch.half, torch.double, torch.int)
def test_cat2(self, device, dtype):
@@ -1028,6 +1094,20 @@ def test_cat2(self, device, dtype):
z = torch.cat([x, y])
self.assertEqual(z.size(), (21, SIZE, SIZE))

self.assertRaises(RuntimeError, lambda: torch.cat([]))
self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None]))

@onlyCPU
def test_cat_scalars(self, device):
x = torch.tensor(0, device=device)
y = torch.tensor(1, device=device)
with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'):
torch.cat([x, y])

def test_zeros_dtype_out_match(self, device):
d = torch.tensor((2, 3), device=device, dtype=torch.double)
self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))

# FIXME: Create an OpInfo-based tensor creation method test that verifies this for all tensor
# creation methods and verify all dtypes and layouts
@dtypes(torch.bool, torch.uint8, torch.int16, torch.int64, torch.float16, torch.float32, torch.complex64)
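The restored tests above pin down `torch.cat`'s `out=` aliasing rule: writing into memory that overlaps any input raises, while a disjoint view of the same storage is accepted. A minimal sketch of both cases, mirroring the shapes used in the restored tests (variable names are illustrative):

```python
import torch

# Disjoint destination: allowed, mirrors the restored test that writes into w[6:10].
w = torch.randn(12)
ok = torch.cat([w[:2], w[4:6]], out=w[6:10])
assert ok.shape == (4,)

# Overlapping destination: out is a view of an input, so cat refuses with
# "unsupported operation: some elements of the input tensor and the
# written-to tensor refer to a single memory location."
y = torch.randn(4, 6)
z = torch.zeros(4, 6)
try:
    torch.cat([y, z], out=z[:2, :])
except RuntimeError as e:
    print(e)
```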
9 changes: 0 additions & 9 deletions torch/_torch_docs.py
Expand Up @@ -2310,15 +2310,6 @@ def merge_dicts(*dicts):
""",
)

add_docstr(
torch.concatenate,
r"""
concatenate(tensors, axis=0, out=None) -> Tensor
Alias of :func:`torch.cat`.
""",
)

add_docstr(
torch.ceil,
r"""
1 change: 0 additions & 1 deletion torch/csrc/jit/passes/normalize_ops.cpp
@@ -129,7 +129,6 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
{aten::true_divide, aten::div},
{aten::true_divide_, aten::div_},
{aten::concat, aten::cat},
{aten::concatenate, aten::cat},
{aten::row_stack, aten::vstack},
{aten::swapdims, aten::transpose},
{aten::swapdims_, aten::transpose_},
1 change: 0 additions & 1 deletion torch/overrides.py
@@ -412,7 +412,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.cartesian_prod: lambda *tensors: -1,
torch.cat: lambda tensors, dim=0, out=None: -1,
torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.concatenate
torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
torch.ceil: lambda input, out=None: -1,
torch.celu: lambda input, alhpa=1., inplace=False: -1,
76 changes: 1 addition & 75 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1871,75 +1871,6 @@ def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs):
for dim in range(-1, len(shape) - 1):
yield SampleInput(tensors, args=(dim,))

def error_inputs_cat(op_info, device, **kwargs):

make_arg = partial(make_tensor, device=device, dtype=torch.float32)

# error inputs for more than one element of the written-to tensor refer to a single memory location
yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))],
kwargs={'out': make_arg((1, S)).expand((2 * S, S))}),
error_regex='unsupported operation')

# error inputs for empty tensors
yield ErrorInput(SampleInput([], kwargs={'dim': 1}),
error_regex='non-empty list of Tensors')

# error inputs for different sizes
yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
error_regex='Sizes of tensors must match except in dimension')
yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}),
error_regex='Sizes of tensors must match except in dimension')

# error inputs for different dimensions
yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
error_regex='Tensors must have same number of dimensions')
yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}),
error_regex='Tensors must have same number of dimensions')

# error inputs for same memory locations
x = torch.zeros((0), device=device)
y = torch.randn((4, 6), device=device)

err_msg = "the written-to tensor refer to a single memory location"

yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}),
error_regex=err_msg)
yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}),
error_regex=err_msg)

z = torch.zeros((4, 6), device=device)
yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}),
error_regex=err_msg)

# error inputs for different devices
if torch.device(device).type == 'cuda':
x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32)
y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32)
yield ErrorInput(SampleInput((x_cuda, y_cpu)),
error_regex='Expected all tensors to be on the same device')

# error inputs for different input sizes for more than 2 tensors
yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]),
error_regex='Tensors must have same number of dimensions')

yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))],
kwargs={'dim': 1}),
error_regex='Sizes of tensors must match')

# error inputs for None input
yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError,
error_regex='got None')

# error inputs for zero-dimensional tensors
yield ErrorInput(SampleInput([make_arg(()), make_arg(())]),
error_regex='zero-dimensional.*cannot be concatenated')

# error inputs for different dtype of out tensors
d = make_tensor((2, 3), device=device, dtype=torch.double)
x = make_tensor((2, 3), device=device, dtype=torch.float32)
yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError,
error_regex='invalid combination of arguments')

def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

@@ -14158,11 +14089,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
)),
OpInfo('cat',
ref=_cat_np,
aliases=('concat', 'concatenate'),
aliases=('concat',),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
sample_inputs_func=sample_inputs_cat_concat,
reference_inputs_func=reference_inputs_cat,
error_inputs_func=error_inputs_cat,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -16958,10 +16888,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"_refs.cat",
torch_opinfo_name="cat",
supports_nvfuser=False,
skips=(
# FIXME: AssertionError: RuntimeError not raised
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
),
),
PythonRefInfo(
"_refs.chunk",
