Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORTModule GraphTransitionManager #19007

Merged
merged 39 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
d1e53e4
save
pengwa Jan 3, 2024
527ccac
save
pengwa Jan 3, 2024
613136a
save
pengwa Jan 3, 2024
96e3d2c
fix all tests
pengwa Jan 4, 2024
b2897a3
fix
pengwa Jan 4, 2024
a01cb88
minor
pengwa Jan 4, 2024
34cdba4
fix
pengwa Jan 4, 2024
8d34f43
fixes
pengwa Jan 5, 2024
d990e0f
fix
pengwa Jan 5, 2024
29c8a98
fix
pengwa Jan 8, 2024
44f9f3f
fixes
pengwa Jan 8, 2024
b33fd93
fix ci
pengwa Jan 8, 2024
a07f21c
fix
pengwa Jan 8, 2024
2168ea7
refine based on review comments
pengwa Feb 21, 2024
92cb745
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 22, 2024
958c837
fix merge
pengwa Feb 22, 2024
2d53141
fix
pengwa Feb 22, 2024
8078d36
fix
pengwa Feb 23, 2024
ea697c0
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 23, 2024
970525b
fix all tests
pengwa Feb 26, 2024
d29e772
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 26, 2024
f45c4b4
minors
pengwa Feb 26, 2024
2c69654
minor
pengwa Feb 26, 2024
c4880c1
fix test
pengwa Feb 27, 2024
302b29e
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 27, 2024
898ead8
fix tests orttraining/orttraining/test/python/orttraining_test_ortmod…
pengwa Feb 27, 2024
a1d1afe
yes, another minor fix
pengwa Feb 27, 2024
1958aed
fix memory efficient grad mangement
pengwa Feb 27, 2024
49cf041
minor
pengwa Feb 27, 2024
e2daf49
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Mar 7, 2024
becb4c5
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Jun 12, 2024
c4ebb6e
lint
pengwa Jun 12, 2024
a526905
fixes
pengwa Jun 12, 2024
cc3871a
fix lints
pengwa Jun 13, 2024
523e63e
minor
pengwa Jun 13, 2024
4737bd4
fix
pengwa Jun 13, 2024
fd2c95a
fix ut
pengwa Jun 13, 2024
870aa30
fix
pengwa Jun 13, 2024
02dee17
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Jun 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix merge
  • Loading branch information
pengwa committed Feb 22, 2024
commit 958c837037d1e5d9aa07454416927a61a19c22c7
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ORTModelInputOutputSchemaType,
ORTModelInputOutputType,
PrimitiveType,
onnx_dtype_to_pytorch_dtype,
unflatten_data_using_schema,
)

Expand All @@ -38,9 +39,12 @@ class ExportedModelInfo:
"""Encapsulates the information of the exported model.

After ONNX model export, the model info is collected and encapsulated in this class, including:
1. The ONNX graph inputs
1. The ONNX graph input names.
2. Graph input requiring gradient information.
3. The model's forward function signature and args/kwargs schema.
3. The model's forward function signature and args/kwargs schema, used as a cache key to compare with the current
inputs to see if the model needs to be re-exported.

This data structure is returned by the GraphTransitionManager._export_model method.

"""

Expand Down Expand Up @@ -120,6 +124,7 @@ def __init__(
module_forward_output_schema: ORTModelInputOutputSchemaType,
post_export_processed_model: onnx.ModelProto,
onnx_graph_input_data_accessor: dict[str, callable],
enable_mem_efficient_grad_management: bool,
):
self._flattened_module = flatten_module

Expand Down Expand Up @@ -152,11 +157,13 @@ def __init__(
# For i-th input name, we can use the i-th function to get the input data from args and kwargs.
self.onnx_graph_input_data_accessor: dict[str, callable] | None = onnx_graph_input_data_accessor

self._enable_mem_efficient_grad_management = enable_mem_efficient_grad_management

# Used for unflattening the outputs from the ORT forward run.
self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema

# A buffer to hold the inputs for the ORT forward run. For performance, we reuse the same buffer for each run.
self._buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict()
self._buffer_for_ort_runs: dict[str, torch.Tensor] | None = None

def __str__(self):
return f"""PostExportProcessedModelInfo class:
Expand All @@ -165,7 +172,7 @@ def __str__(self):
\tonnx_graph_input_dynamic_axes_map: {self.onnx_graph_input_dynamic_axes_map}
\tonnx_graph_input_names_user_defined: {self.onnx_graph_input_names_user_defined}
\tonnx_graph_input_names_require_grad_user_defined: {self.onnx_graph_input_names_require_grad_user_defined}
\tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys()}
\tbuffer_for_ort_runs.keys(): {self._buffer_for_ort_runs.keys() if self._buffer_for_ort_runs else None}
"""

def __repr__(self):
Expand All @@ -182,12 +189,27 @@ def construct_inputs(

The inputs are constructed in the order they appear in the model's forward function signature
"""
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)

# First time construct the buffer for the ORT forward run.
if len(self._buffer_for_ort_runs) == 0:
if self._buffer_for_ort_runs is None:
self._buffer_for_ort_runs = OrderedDict()

# Create the buffers for the inputs that are either parameters or buffers in the original module.
# For user inputs, fill with None for now, and will be filled dynamically during the forward run.
parameter_names = {k: v for k, v in self._flattened_module.named_parameters()}

if self._enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

parameter_names = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters(), self._post_export_processed_model
)
else:
parameter_names = {k: v for k, v in self._flattened_module.named_parameters()}
buffer_names = {k: v for k, v in self._flattened_module.named_buffers()}
for input_name in self.onnx_graph_input_names:
if input_name in parameter_names:
Expand All @@ -198,6 +220,14 @@ def construct_inputs(
self._buffer_for_ort_runs[input_name] = None

for name in self.onnx_graph_input_names_user_defined:
if self._enable_mem_efficient_grad_management and name == MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME:
self._buffer_for_ort_runs[name] = torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
continue

if name in self.onnx_graph_input_data_accessor:
assert name in self._buffer_for_ort_runs, f"{name} is not in buffer_for_ort_runs"
data = self.onnx_graph_input_data_accessor[name](args, kwargs)
Expand Down Expand Up @@ -316,26 +346,26 @@ def get_post_processed_model(
#
# The _io.FlattenedModule serves as a module wrapper designed to support tuple inputs and outputs for
# PyTorch run during ONNX export. (Remember the PyTorch exporter handles tuple inputs and outputs better.)
# Internally, it facilitates the acceptance of tuple inputs and generation of tuple outputs by invoking
# Internally, it facilitates the acceptance of tuple inputs and the generation of tuple outputs by invoking
# the original module's forward function. The workflow involves the following steps:

# 1. Prior to export, both args and kwargs are flattened into a 1-D tensor list, and a schema for the
# flattened args and kwargs is generated. This schema is essential for the subsequent unflattening
# 1. Prior to export, both args and kwargs are flattened into a 1-D tensor list, and schemas for the
flattened args and kwargs are generated. These schemas are essential for the subsequent un-flattening
# process.

# 2. The flattened inputs (args + kwargs) are passed to the _io.FlattenedModule's forward run.

# 3. The args schema and kwargs schema, etc are conveyed to the _io.FlattenedModule by setting the
# corresponding attributes.

# 4. Within the _io.FlattenedModule's forward run, the inputs are unflattened to the original args and
# 4. Within the _io.FlattenedModule's forward run, the inputs are un-flattened to the original args and
# kwargs using the associated schemas, and then they are passed to the original module's forward function.

# 5. Upon the completion of the forward function, the outputs from the original module are flattened and
# returned to the caller.

# 6. The 1-D flattened output tensors retain the same order as the outputs from the ONNX Runtime (ORT)
# forward run. To facilitate unflattening during subsequent ORT runs, the output schema is saved as
# forward run. To facilitate un-flattening during subsequent ORT runs, the output schema is saved as
# an attribute named `_output_schema` in the _io.FlattenedModule.

copied_args = copy.copy(args)
Expand Down Expand Up @@ -446,6 +476,8 @@ def get_post_processed_model(
enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support,
run_symbolic_shape_infer=self._runtime_options.run_symbolic_shape_infer,
stage3_param_handle=self,
enable_mem_efficient_grad_management=self._export_mode != torch.onnx.TrainingMode.EVAL
and self._runtime_options.enable_mem_efficient_grad_management,
logger=self._logger,
)

Expand All @@ -465,7 +497,7 @@ def get_post_processed_model(

@staticmethod
def _export_check(
prev_exported_model_info: ExportedModelInfo,
prev_exported_model_info: ExportedModelInfo | None,
original_model_has_changed: bool,
cur_args_schema: ORTModelInputOutputSchemaType,
cur_kwargs_schema: ORTModelInputOutputSchemaType,
Expand Down Expand Up @@ -549,13 +581,12 @@ def _post_export_process(
enable_zero_stage3_support: bool,
run_symbolic_shape_infer: bool,
stage3_param_handle: type,
enable_mem_efficient_grad_management: bool,
logger: logging.Logger,
):
"""Post process the exported model, generate the processed model which will be used for initializing graph builder."""

# Deepcopy the exported model, in case modification affects the exported model.

# TODO(): Do pre-grad graph modification as needed, for memory-efficient gradient management, etc.
post_processed_model = copy.deepcopy(exported_model_info.exported_model)

if enable_custom_autograd_function:
Expand All @@ -578,16 +609,46 @@ def _post_export_process(
[name for name, _ in flatten_module.named_parameters()],
)

onnx_graph_input_names_user_defined = copy.deepcopy(exported_model_info.onnx_graph_input_names_user_defined)
onnx_graph_input_names_require_grad_user_defined = copy.deepcopy(
exported_model_info.onnx_graph_input_names_require_grad_user_defined
)
onnx_graph_input_names = copy.deepcopy(exported_model_info.onnx_graph_input_names)
onnx_graph_input_names_require_grad = copy.deepcopy(exported_model_info.onnx_graph_input_names_require_grad)
if enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training

# Override the options if model is not modified.
(
enable_mem_efficient_grad_management,
post_processed_model,
) = post_processing_enable_mem_efficient_training(post_processed_model, flatten_module.named_parameters())

if enable_custom_autograd_function:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
onnx_graph_input_names_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
onnx_graph_input_names_require_grad_user_defined.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
onnx_graph_input_names.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
onnx_graph_input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)

if run_symbolic_shape_infer:
post_processed_model = SymbolicShapeInference.infer_shapes(
post_processed_model, auto_merge=True, guess_output_rank=True
)

post_export_processed_model_info = PostExportProcessedModelInfo(
flatten_module,
exported_model_info.onnx_graph_input_names_user_defined,
exported_model_info.onnx_graph_input_names_require_grad_user_defined,
exported_model_info.onnx_graph_input_names,
exported_model_info.onnx_graph_input_names_require_grad,
onnx_graph_input_names_user_defined,
onnx_graph_input_names_require_grad_user_defined,
onnx_graph_input_names,
onnx_graph_input_names_require_grad,
model_info_for_export.onnx_graph_input_dynamic_axes_map,
exported_model_info.module_forward_output_schema,
post_processed_model,
model_info_for_export.onnx_graph_input_data_accessor,
enable_mem_efficient_grad_management,
)

logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_params_not_connected_to_pull_param_trigger(
):
# Note that some parameters might not be in the graph inputs because they are not used in forward, so we filter them out as well.
onnx_initializer_names = {p.name for p in exported_model.graph.input}
return [v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names]
return {k: v for k, v in named_params if not v.requires_grad and k in onnx_initializer_names}


def post_processing_enable_mem_efficient_training(
Expand Down
Loading