[fx2trt] add support for torch.tile (pytorch#66016)
Summary:
Pull Request resolved: pytorch#66016

Add acc_ops.tile and converter for it.
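
For reference, torch.tile repeats a tensor along each dimension according to dims; if dims has more entries than the input has dimensions, the input is treated as if it had additional leading dimensions of size one. A minimal sketch of the semantics the converter has to reproduce (example values are illustrative, not part of this commit):

import torch

x = torch.tensor([[1, 2], [3, 4]])   # shape (2, 2)
torch.tile(x, (2, 3)).shape          # (4, 6): repeat 2x along dim 0, 3x along dim 1
torch.tile(x, (2, 1, 1)).shape       # (2, 2, 2): extra leading entry acts like a new size-1 dim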

Test Plan: buck test mode/dev-nosan caffe2/torch/fb/fx2trt:test_tile

Reviewed By: wushirong

Differential Revision: D30587939

fbshipit-source-id: 1e2613cfca486fe54fcc0d38e5c7cdeb7d0ed4a0
842974287 authored and facebook-github-bot committed Oct 1, 2021
1 parent 060e41e commit f85d742
Showing 2 changed files with 74 additions and 0 deletions.
torch/fx/experimental/fx2trt/converters/acc_ops_converters.py (70 additions, 0 deletions)
@@ -578,6 +578,76 @@ def get_softmax_dim(ndim):
    layer.name = name
    return layer.get_output(0)

@tensorrt_converter(acc_ops.tile)
def acc_ops_tile(network, target, args, kwargs, name):
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"tile received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    dims = kwargs["dims"]
    n_input_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)

    if len(dims) > n_input_dims:
        assert not network.has_implicit_batch_dimension
        layer = network.add_shuffle(input_val)
        layer.name = f"{name}_reshape"
        num_preceding_ones = len(dims) - n_input_dims

        if len(get_dynamic_dims(input_val.shape)) > 1:
            input_shape_layer = network.add_shape(input_val)
            input_shape_layer.name = f"{name}_input_shape"
            preceding_ones = network.add_constant(
                (num_preceding_ones,), np.ascontiguousarray([1] * num_preceding_ones, np.int32)
            ).get_output(0)
            reshape_layer = network.add_concatenation([preceding_ones, input_shape_layer.get_output(0)])
            reshape_layer.axis = 0
            reshape_layer.name = f"{name}_reshape_dims"
            layer.set_input(1, reshape_layer.get_output(0))
        else:
            layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(input_val.shape)
        input_val = layer.get_output(0)
    else:
        dims = (1,) * (n_input_dims - len(dims)) + dims

    if network.has_implicit_batch_dimension:
        assert dims[0] == 1, "Can't tile the batch dim when it's implicit."
        dims = dims[1:]

    starts = [0] * len(dims)
    shapes = [i * j for i, j in zip(input_val.shape, dims)]
    # If there's a dynamic dim then shapes would contain negative values, which is not allowed.
    # Here we build a dummy shapes array and set the real shape input on the layer below.
    if has_dynamic_shape(input_val.shape):
        shapes = [1] * len(dims)
    strides = [1] * len(dims)
    layer = network.add_slice(input_val, starts, shapes, strides)
    layer.mode = trt.SliceMode.WRAP
    layer.name = name

    if has_dynamic_shape(input_val.shape):
        starts_tensor = network.add_constant(
            (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
        ).get_output(0)
        dims_tensor = network.add_constant(
            (len(dims),), np.ascontiguousarray(dims, np.int32)
        ).get_output(0)
        input_shape_layer = network.add_shape(input_val)
        input_shape_layer.name = f"{name}_slice_input_shape"
        slice_shapes_tensor = add_binary_elementwise_layer(
            network,
            input_shape_layer.get_output(0),
            dims_tensor,
            trt.ElementWiseOperation.PROD,
            f"{name}_slice_shapes",
        )
        layer.set_input(1, starts_tensor)
        layer.set_input(2, slice_shapes_tensor)

    return layer.get_output(0)

@tensorrt_converter(acc_ops.relu)
def acc_ops_relu(network, target, args, kwargs, name):
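The converter above expresses tile as a TensorRT slice layer in WRAP mode: the output extent of each axis is the input extent times the corresponding entry in dims, and reads past the end of an axis wrap around to its start, which reproduces the repetition. A rough NumPy sketch of that wrap-slice equivalence (illustration only, not part of the diff):

import numpy as np

def tile_via_wrap_slice(x, dims):
    # Output extent per axis = input extent * repeat count, mirroring the converter's `shapes`.
    out_shape = [s * d for s, d in zip(x.shape, dims)]
    out = np.empty(out_shape, dtype=x.dtype)
    for idx in np.ndindex(*out_shape):
        # WRAP behavior: each index is read modulo the input extent.
        out[idx] = x[tuple(i % s for i, s in zip(idx, x.shape))]
    return out

x = np.arange(6).reshape(2, 3)
assert np.array_equal(tile_via_wrap_slice(x, (2, 2)), np.tile(x, (2, 2)))
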
torch/fx/experimental/fx_acc/acc_ops.py (4 additions, 0 deletions)
@@ -185,6 +185,10 @@ def add(*, input, other):
def unsqueeze(*, input, dim):
    return torch.unsqueeze(**locals())

@register_acc_op_mapping(op_and_target=("call_function", torch.tile))
@register_acc_op
def tile(*, input, dims):
    return torch.tile(**locals())

@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.stack),
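The acc_op wrapper above is keyword-only and forwards its arguments with **locals(), so it matches torch.tile called with keyword arguments. A quick standalone check of that forwarding pattern (outside the registration machinery, for illustration only):

import torch

def tile(*, input, dims):
    # locals() here is exactly {"input": ..., "dims": ...}, which torch.tile accepts as keywords.
    return torch.tile(**locals())

x = torch.arange(4).reshape(2, 2)
assert torch.equal(tile(input=x, dims=(1, 2)), torch.tile(x, (1, 2)))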
