[fx2trt] fix a bug in conversion from negative dim to positive dim (pytorch#68360)

Summary:
Pull Request resolved: pytorch#68360

Added a helper function for this conversion: it only applies `mod` when the dim is negative, and does nothing when the dim is already positive.

Previously, in `getitem`, slicing all the way to the end of a dimension produced the wrong dimension size.
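
To make the failure mode concrete, here is a minimal standalone sketch of the rule this change adopts (it mirrors the `get_positive_dim` helper added in `converter_utils.py` below; the variable `n` is only for illustration). The unconditional `%` used before silently wraps an already-positive boundary value, such as a slice stop equal to the dimension size, back to 0:

def get_positive_dim(dim: int, dim_size: int) -> int:
    # Only wrap negative dims; non-negative values pass through untouched.
    return dim % dim_size if dim < 0 else dim

n = 10
assert get_positive_dim(-1, n) == 9   # negative dims still wrap, as before
assert get_positive_dim(3, n) == 3    # positive dims are returned unchanged
assert get_positive_dim(n, n) == n    # a stop of 10 on a size-10 dim stays 10 ...
assert n % n == 0                     # ... whereas unconditional `%` collapses it to 0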

Test Plan: Added a unit test (the `slice_full` case in test_getitem.py) covering a slice that runs to the very end of a dimension.

Reviewed By: yinghai, wushirong

Differential Revision: D32432893

fbshipit-source-id: 3c5d6a578d92a15207a5e52802750f9ea7f272a9
842974287 authored and facebook-github-bot committed Nov 15, 2021
1 parent 549e014 commit 0cf46fb
Showing 3 changed files with 36 additions and 17 deletions.
1 change: 1 addition & 0 deletions test/fx2trt/converters/acc_op/test_getitem.py
@@ -12,6 +12,7 @@ class TestGetitemConverter(AccTestCase):
         [
             ("slice_batch_dim", slice(None, None, None)),
             ("slice_basic", (slice(None, None, None), slice(0, 3, 2))),
+            ("slice_full", (slice(None, None, None), slice(0, 10, 3))),
             ("ellipsis", (slice(None, None, None), ..., slice(0, 3, 2))),
             (
                 "slice_all_none",
34 changes: 17 additions & 17 deletions torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -27,6 +27,7 @@
     add_unary_layer,
     add_activation_layer,
     extend_attr_to_tuple,
+    get_positive_dim,
 )


@@ -214,8 +215,8 @@ def acc_ops_flatten(network, target, args, kwargs, name):
         )

     num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    start_dim = (kwargs["start_dim"] if "start_dim" in kwargs else 0) % num_dims
-    end_dim = (kwargs["end_dim"] if "end_dim" in kwargs else -1) % num_dims
+    start_dim = get_positive_dim(kwargs["start_dim"] if "start_dim" in kwargs else 0, num_dims)
+    end_dim = get_positive_dim(kwargs["end_dim"] if "end_dim" in kwargs else -1, num_dims)

     if network.has_implicit_batch_dimension:
         assert start_dim != 0, "Can't flatten batch dimension when it's implicit."
@@ -443,7 +444,7 @@ def get_softmax_dim(ndim):
     if dim is None:
         dim = get_softmax_dim(input_ranks)

-    dim = dim % input_ranks
+    dim = get_positive_dim(dim, input_ranks)
     if network.has_implicit_batch_dimension:
         assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
         dim -= 1
@@ -835,7 +836,7 @@ def acc_ops_squeeze(network, target, args, kwargs, name):
     # dim, which is a very rare case. For now we just claim not supporting dim=None.
     assert dim is not None, "We don't support dim=None right now."

-    dim = dim % (len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
+    dim = get_positive_dim(dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
     if network.has_implicit_batch_dimension:
         assert dim != 0, "We don't support squeeze batch dim when it's implicit."
         dim -= 1
@@ -900,8 +901,8 @@ def acc_ops_unsqueeze(network, target, args, kwargs, name):
     dim = kwargs["dim"]
     input_shape = input_val.shape
     input_shape_size = len(input_val.shape) + 1 if network.has_implicit_batch_dimension else len(input_val.shape)
-    if dim < 0:
-        dim = dim % (input_shape_size + 1)
+    dim = get_positive_dim(dim, input_shape_size + 1)
+
     if network.has_implicit_batch_dimension:
         assert dim != 0
         dim -= 1
@@ -929,7 +930,7 @@ def acc_ops_topk(network, target, args, kwargs, name):

     num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
     k = kwargs["k"]
-    dim = (kwargs["dim"] if kwargs["dim"] is not None else -1) % num_dims
+    dim = get_positive_dim(kwargs["dim"] if kwargs["dim"] is not None else -1, num_dims)
     operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN
     layer = network.add_topk(
         input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
@@ -1061,7 +1062,7 @@ def acc_ops_slice_tensor(network, target, args, kwargs, name):
                            "of the TensorRT region!")

     ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    dims = [dim % ranks for dim in kwargs["dims"]]
+    dims = [get_positive_dim(dim, ranks) for dim in kwargs["dims"]]

     if network.has_implicit_batch_dimension:
         if not len(dims):
@@ -1262,9 +1263,9 @@ def slice_to_trt_params(py_slice, dim_size):
         """
         Convert python slice to TensorRT slice layer parameters.
         """
-        start = (py_slice.start % dim_size) if py_slice.start else 0
+        start = get_positive_dim(py_slice.start, dim_size) if py_slice.start else 0
         stride = py_slice.step if py_slice.step else 1
-        stop = (py_slice.stop % dim_size) if py_slice.stop else dim_size
+        stop = get_positive_dim(py_slice.stop, dim_size) if py_slice.stop else dim_size
         size = math.ceil((stop - start) * 1.0 / stride)
         return start, size, stride
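
As a concrete check of the hunk above, the snippet below restates the fixed `slice_to_trt_params` together with the `get_positive_dim` helper as standalone code (a sketch mirroring the diff, not the module itself) and runs it on the new `slice_full` test case, `slice(0, 10, 3)` on a dimension of size 10:

import math

def get_positive_dim(dim, dim_size):
    return dim % dim_size if dim < 0 else dim

def slice_to_trt_params(py_slice, dim_size):
    # Post-fix logic from the hunk above.
    start = get_positive_dim(py_slice.start, dim_size) if py_slice.start else 0
    stride = py_slice.step if py_slice.step else 1
    stop = get_positive_dim(py_slice.stop, dim_size) if py_slice.stop else dim_size
    size = math.ceil((stop - start) * 1.0 / stride)
    return start, size, stride

print(slice_to_trt_params(slice(0, 10, 3), 10))  # (0, 4, 3): four elements, as expected
# With the old `py_slice.stop % dim_size`, stop became 10 % 10 == 0 and size came out 0.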

@@ -1275,7 +1276,7 @@ def slice_to_trt_params(py_slice, dim_size):
         # Raise an error if it's trying to subscript batch dimension unless it's
         # slice(None, None, None).
         batch_subscript = slices[0]
-        if batch_subscript != slice(None, None, None):
+        if batch_subscript not in [slice(None, None, None), slice(0, None, None)]:
             raise RuntimeError(
                 f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}"
             )
@@ -1312,7 +1313,7 @@ def slice_to_trt_params(py_slice, dim_size):
             size.append(params[1])
             stride.append(params[2])
         else:
-            start.append(s % input_val.shape[i])
+            start.append(get_positive_dim(s, input_val.shape[i]))
             size.append(1)
             stride.append(1)
         i += 1
@@ -1416,7 +1417,7 @@ def acc_ops_sigmoid(network, target, args, kwargs, name):
 def acc_ops_permute(network, target, args, kwargs, name):
     input_val = kwargs["input"]
     ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    permutation = [i % ranks for i in kwargs["permutation"]]
+    permutation = [get_positive_dim(i, ranks) for i in kwargs["permutation"]]

     if not isinstance(input_val, trt.tensorrt.ITensor):
         raise RuntimeError(
@@ -1590,11 +1591,11 @@ def acc_ops_chunk(network, target, args, kwargs, name):

     if network.has_implicit_batch_dimension:
         input_dim_size += 1
-        dim = dim % input_dim_size
+        dim = get_positive_dim(dim, input_dim_size)
         assert dim != 0, "Can't chunk on batch dim when it's implicit!"
         dim -= 1
     else:
-        dim = dim % input_dim_size
+        dim = get_positive_dim(dim, input_dim_size)

     if chunks > input_val.shape[dim]:
         warnings.warn(
@@ -1637,8 +1638,7 @@ def acc_ops_cumsum(network, target, args, kwargs, name):
         raise RuntimeError(
             "cumsum converter currently doesn't support implicit batch dimension"
         )
-    if dim < 0:
-        dim = dim % input_dim_size
+    dim = get_positive_dim(dim, input_dim_size)
     loop = network.add_loop()
     trip_limit = None
     if (input_shape[dim] > 0):
18 changes: 18 additions & 0 deletions torch/fx/experimental/fx2trt/converters/converter_utils.py
@@ -36,6 +36,24 @@ def get_trt_plugin(
     return plugin


+def get_positive_dim(dim: int, dim_size: int) -> int:
+    """
+    Given an integer number that represents a dimension in the array,
+    transform it to a positive integer dim if it's negative. Otherwise, do
+    nothing.
+
+    Args:
+        dim (int): An integer number that represents a dimension in an array.
+        dim_size (int): The size of the dimension in the array.
+
+    Returns:
+        A positive integer that represents the same dimension as the given dim.
+    """
+    if dim < 0:
+        return dim % dim_size
+    return dim
+
+
 def set_layer_name(layer: trt.ILayer, target: Target, name: str) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
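For reference, a quick usage check of the new helper; the import path below simply mirrors the file touched in this diff and assumes this commit's module layout:

from torch.fx.experimental.fx2trt.converters.converter_utils import get_positive_dim

assert get_positive_dim(-3, 4) == 1   # negative dims are wrapped with `mod`
assert get_positive_dim(2, 4) == 2    # in-range positive dims are returned as-is
assert get_positive_dim(4, 4) == 4    # boundary values (e.g. slice stops) are preserved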
