Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORTModule memory improvement #18924

Merged
merged 17 commits into from
Jan 16, 2024
Prev Previous commit
Next Next commit
minors
  • Loading branch information
pengwa committed Jan 15, 2024
commit 2deee943e58840a20b4f0131967859ef061cbe09
Original file line number Diff line number Diff line change
Expand Up @@ -520,11 +520,6 @@ def _initialize_graph_builder(self):
exported_model,
) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())

# if self._runtime_options.run_symbolic_shape_infer:
# exported_model = SymbolicShapeInference.infer_shapes(
# exported_model, auto_merge=True, guess_output_rank=True
# )

# All initializer names along with user inputs are a part of the onnx graph inputs
# since the onnx model was exported with the flag keep_initializers_as_inputs=True
# We need to use the raw exported model here since the graph inputs include both user inputs and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,8 @@ def forward(self, *inputs, **kwargs):
if self._runtime_options.enable_zero_stage3_support:
self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = self._graph_initializers

prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
param_to_append_as_onnx_graph_inputs,
self._graph_initializers,
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def post_processing_enable_mem_efficient_training(

"""
trainable_named_params = get_params_connected_to_pull_param_trigger(named_params, exported_model)
# print(exported_model.graph.input)
if len(trainable_named_params) == 0:
return False, exported_model

Expand Down Expand Up @@ -103,7 +102,7 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
),
graph_input.name, # Second param is a string, which represent the param_name
graph_input.name, # Second param is a string, which represents the param_name
]

node_outputs = [
Expand All @@ -126,7 +125,6 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
input_offset += 1

# Delete exported_model.graph.input

names_to_remove = [input.name for input in graph_inputs_to_remove]
value_infos_to_remove = [
value_info for value_info in exported_model.graph.value_info if value_info.name in names_to_remove
Expand All @@ -138,7 +136,7 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
exported_model.graph.input.remove(input_to_remove)

# Re-order graph input to make sure the weight pull trigger is the first user input.
offset = 0 # Find the first trainable param, insert the new input before it, as part of user inputs.
offset = 0 # Find the first trainable param, and insert the new input before it, as part of user inputs.
for input in exported_model.graph.input:
if input.name in named_params:
break
Expand Down