Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORTModule GraphTransitionManager #19007

Merged
merged 39 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
d1e53e4
save
pengwa Jan 3, 2024
527ccac
save
pengwa Jan 3, 2024
613136a
save
pengwa Jan 3, 2024
96e3d2c
fix all tests
pengwa Jan 4, 2024
b2897a3
fix
pengwa Jan 4, 2024
a01cb88
minor
pengwa Jan 4, 2024
34cdba4
fix
pengwa Jan 4, 2024
8d34f43
fixes
pengwa Jan 5, 2024
d990e0f
fix
pengwa Jan 5, 2024
29c8a98
fix
pengwa Jan 8, 2024
44f9f3f
fixes
pengwa Jan 8, 2024
b33fd93
fix ci
pengwa Jan 8, 2024
a07f21c
fix
pengwa Jan 8, 2024
2168ea7
refine based on review comments
pengwa Feb 21, 2024
92cb745
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 22, 2024
958c837
fix merge
pengwa Feb 22, 2024
2d53141
fix
pengwa Feb 22, 2024
8078d36
fix
pengwa Feb 23, 2024
ea697c0
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 23, 2024
970525b
fix all tests
pengwa Feb 26, 2024
d29e772
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 26, 2024
f45c4b4
minors
pengwa Feb 26, 2024
2c69654
minor
pengwa Feb 26, 2024
c4880c1
fix test
pengwa Feb 27, 2024
302b29e
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 27, 2024
898ead8
fix tests orttraining/orttraining/test/python/orttraining_test_ortmod…
pengwa Feb 27, 2024
a1d1afe
yes, another minor fix
pengwa Feb 27, 2024
1958aed
fix memory efficient grad mangement
pengwa Feb 27, 2024
49cf041
minor
pengwa Feb 27, 2024
e2daf49
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Mar 7, 2024
becb4c5
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Jun 12, 2024
c4ebb6e
lint
pengwa Jun 12, 2024
a526905
fixes
pengwa Jun 12, 2024
cc3871a
fix lints
pengwa Jun 13, 2024
523e63e
minor
pengwa Jun 13, 2024
4737bd4
fix
pengwa Jun 13, 2024
fd2c95a
fix ut
pengwa Jun 13, 2024
870aa30
fix
pengwa Jun 13, 2024
02dee17
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Jun 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix test
  • Loading branch information
pengwa committed Feb 27, 2024
commit c4880c146aeef5948b16a1ef0bf3228f525c256d
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def get_post_processed_model(
True,
self._device,
self._export_mode,
self._logger,
self._export_extra_kwargs,
)

Expand Down Expand Up @@ -389,7 +390,7 @@ def get_post_processed_model(
flatten_inputs = []

# This looks a bit duplicated with `extract_data_and_schema` function, but this might be better to
# defined as a specialized logic which is the counter-part of `parse_inputs_for_onnx_export`, which handles
# defined as a specialized logic that is the counter-part of `parse_inputs_for_onnx_export`, which handles
# args and kwargs separately.
for name, data_accessor in cur_model_info_for_export.onnx_graph_input_data_accessor.items():
d = data_accessor(copied_args, copied_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions orttraining/orttraining/python/training/ortmodule/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import copy
import gc
import inspect
import warnings
from collections import OrderedDict, abc
from functools import partial
from logging import Logger
Expand Down Expand Up @@ -244,6 +243,7 @@ def parse_inputs_for_onnx_export(
constant_as_tensor: bool,
device: torch.device,
export_mode: int,
logger: Logger,
export_extra_kwargs: Optional[Dict[str, any]] = None,
) -> ModelInfoForExport:
"""Parses through the model inputs and returns _InputInfo.
Expand Down Expand Up @@ -282,7 +282,7 @@ def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]:
return dynamic_axes

def _warn_of_constant_inputs(data):
warnings.warn(f"Received input of type {type(data)} is treated as a constant by ORT by default.")
logger.info(f"Received input of type {type(data)} is treated as a constant by ORT by default.")

def _add_input(
name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable, tensor_idx: List[int]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4239,14 +4239,23 @@ def test_hf_save_pretrained():
assert p1.data.ne(p2.data).sum() == 0


def test_ortmodule_string_inputs_are_ignored():
def test_ortmodule_string_inputs_are_ignored(caplog):
pt_model = MyStrNet()
target_str = "Received input of type <class 'str'> which may be treated as a constant by ORT by default."
with pytest.warns(UserWarning, match=target_str):
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO))
x = torch.randn(1, 2)
out = ort_model(x, "hello")
_test_helpers.assert_values_are_close(out, x + 1)
target_str = "Received input of type <class 'str'> is treated as a constant by ORT by default."

ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO))
x = torch.randn(1, 2)
out = ort_model(x, "hello")
_test_helpers.assert_values_are_close(out, x + 1)

found_log = False
for record in caplog.records:
msg = record.getMessage()
if target_str in msg:
found_log = True
break

assert found_log, f"Expected to find log message '{target_str}' in the logs, but didn't find it."


def test_ortmodule_list_input():
Expand Down