Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ATen support for bicubic interpolation #19380

Merged
merged 3 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"upsample_nearest1d": self._infer_aten_upsample,
"upsample_nearest2d": self._infer_aten_upsample,
"upsample_nearest3d": self._infer_aten_upsample,
"upsample_bicubic2d": self._infer_aten_upsample,
}
self.run_ = True
self.suggested_merge_ = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,10 @@


# PyTorch removed related backward functions with "vec" overload name since 1.13. The functions with no overload name
# are available for all versions, though they are not that convienent to use.

Check warning on line 241 in orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "convienent" is a misspelling of "convenient" Raw Output: ./orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py:241:59: "convienent" is a misspelling of "convenient"
def _upsample_gradient(backward_fn, dims):
scales = ["" for _ in range(dims)]
if "bilinear" in backward_fn:
if "bicubic" in backward_fn:
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
scales = ["I(2)", *scales]
return [
("Shape", ["I(0)"], ["Shape_X"]),
Expand Down Expand Up @@ -271,3 +271,8 @@
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec")
def upsample_nearest3d_gradient():
return _upsample_gradient("upsample_nearest3d_backward", 3)


@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
def upsample_bicubic2d_gradient():
return _upsample_gradient("upsample_bicubic2d_backward", 2)
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,16 @@ def upsample_nearest2d(g, input, output_size, scale_factors):
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d")


@register_symbolic("upsample_bicubic2d")
def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors):
return g.op(
"org.pytorch.aten::ATen",
input,
output_size,
align_corners,
scale_factors,
operator_s="upsample_bicubic2d",
overload_name_s="vec",
)
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,34 @@ def run_step(model, input):
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)


def test_aten_upsample_bicubic():
class _NeuralNetUpsampleBicubic(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.nn.functional.interpolate(input, size=(8, 12), mode="bicubic")

device = "cuda"
pt_model = _NeuralNetUpsampleBicubic().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))

def run_step(model, input):
prediction = model(input)
prediction.sum().backward()
return prediction

# reset manual seed to reset the generator
torch.manual_seed(2333)
pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)

_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)


def test_gradient_correctness_cast_chain():
class NeuralNetCast(torch.nn.Module):
def __init__(self, D):
Expand Down
Loading