Skip to content

Commit

Permalink
Add Gradient for Reciprocal (#16945)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Aug 4, 2023
1 parent 555414f commit e5bb7ab
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 0 deletions.
9 changes: 9 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1994,5 +1994,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetLSTMGradient) {
return {NodeDef(OpDef{"LSTMGrad", kMSDomain, 1}, input_args, output_args, SrcNodeAttributes())};
}

IMPLEMENT_GRADIENT_BUILDER(GetReciprocalGradient) {
  // Forward:  y = 1 / x
  // Backward: dy/dx = -1 / x^2, which can be written as -(y * y).
  //           dL/dx = dL/dy * dy/dx = dL/dy * -(y * y)
  // The graph below squares O(0) (taken to be the forward output y), negates
  // the square, and scales the incoming gradient GO(0) by it to produce GI(0).
  const auto squared_output = IA("Square_O0");
  const auto negated_square = IA("Neg_Square_O0");

  return {
      NodeDef("Mul", {O(0), O(0)}, {squared_output}),       // y * y
      NodeDef("Neg", {squared_output}, {negated_square}),   // -(y * y)
      NodeDef("Mul", {GO(0), negated_square}, {GI(0)})};    // dL/dy * -(y * y)
}

} // namespace training
} // namespace onnxruntime
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ DECLARE_GRADIENT_BUILDER(GetScatterElementsGradient)
DECLARE_GRADIENT_BUILDER(GetTriluGradient)
DECLARE_GRADIENT_BUILDER(GetFakeQuantGradient)
DECLARE_GRADIENT_BUILDER(GetLSTMGradient)
DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)

DECLARE_GRADIENT_BUILDER(GetExternalGradient)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Trilu", GetTriluGradient);
REGISTER_GRADIENT_BUILDER("FakeQuant", GetFakeQuantGradient);
REGISTER_GRADIENT_BUILDER("LSTMTraining", GetLSTMGradient);
REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};
Expand Down
6 changes: 6 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3027,6 +3027,12 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
}
#endif

TEST(GradientCheckerTest, ReciprocalGrad) {
  // Reciprocal divides by its input, so keep test inputs away from 0: the
  // transformer moves positive values up by 0.2 and all other values down by
  // 0.2, leaving every transformed value with magnitude of at least 0.2.
  std::function<float(float)> shift_away_from_zero = [](float value) {
    if (value > 0) {
      return value + 0.2f;
    }
    return value - 0.2f;
  };
  UnaryOpGradientTest("Reciprocal", kOnnxDomain, 12, nullptr, &shift_away_from_zero);
}

} // namespace test
} // namespace onnxruntime

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6140,3 +6140,34 @@ def forward(self, x):
torch.onnx.export.reset_mock()

del os.environ["ORTMODULE_CACHE_DIR"]


def test_reciprocal_gradient():
    """Compare ORTModule against PyTorch for the forward and backward pass of ``1 / x``."""

    class ReciprocalModel(torch.nn.Module):
        def forward(self, x):
            return 1 / x

    def run_step(model, x):
        # One forward pass, a scalar loss via sum(), and a backward pass that
        # populates x.grad.
        output = model(x)
        summed = output.sum()
        summed.backward()
        return output, summed

    device = "cuda"
    pt_model = ReciprocalModel().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    # Keep the input away from 0 so that 1 / x stays finite.
    # NOTE: the tensor starts as all zeros, so the `<= 0` mask selects every
    # element and the whole input becomes -0.2; the `> 0` adjustment that
    # follows then matches nothing.
    pt_x = torch.zeros(3, 224, 224, requires_grad=True, device=device)
    with torch.no_grad():
        pt_x[pt_x <= 0] -= 0.2
        pt_x[pt_x > 0] += 0.2
    ort_x = copy.deepcopy(pt_x)

    pt_prediction, pt_loss = run_step(pt_model, pt_x)
    ort_prediction, ort_loss = run_step(ort_model, ort_x)

    # Forward output, scalar loss and input gradient must all agree.
    for expected, actual in (
        (pt_prediction, ort_prediction),
        (pt_loss, ort_loss),
        (pt_x.grad, ort_x.grad),
    ):
        _test_helpers.assert_values_are_close(expected, actual)

0 comments on commit e5bb7ab

Please sign in to comment.