ORTModule memory improvement #18924

Merged
merged 17 commits on Jan 16, 2024
remove stage3 related change
(cherry picked from commit be122e3)
pengwa committed Dec 25, 2023
commit 967d544e7463aa983a1b42b6f82ca03be68a8a8f
@@ -529,11 +529,13 @@ def _initialize_graph_builder(self):

# Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)

if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Add stage3 mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
# Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.

input_names_require_grad.insert(0, MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
grad_builder_config.input_names_require_grad = input_names_require_grad
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
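
This hunk switches the memory-efficient gradient trigger from append to insert(0, ...), so the trigger name becomes the first entry in input_names_require_grad rather than the last, matching its later position as the first user input. A minimal sketch of the ordering difference, assuming a placeholder trigger string and hypothetical user-input names:

# Illustrative only: ordering difference between append and insert(0, ...).
trigger = "mem_efficient_pull_weight_trigger"   # placeholder, not the real constant value
names = ["input_ids", "attention_mask"]         # hypothetical user-input names

appended = names.copy()
appended.append(trigger)      # ['input_ids', 'attention_mask', 'mem_efficient_pull_weight_trigger']

inserted = names.copy()
inserted.insert(0, trigger)   # ['mem_efficient_pull_weight_trigger', 'input_ids', 'attention_mask']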
@@ -613,10 +615,20 @@ def _enable_conditional_optimizations(
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
self._append_pull_weight_trigger_as_input(kwargs, detected_device)
kwargs = self._append_pull_weight_trigger_as_input(kwargs, detected_device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters()
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

_, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
self._graph_initializers,
param_to_append_as_onnx_graph_inputs,
self._graph_builder.get_graph_info().user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
@@ -648,25 +660,43 @@ def _enable_conditional_optimizations(
self._runtime_inspector.disable_input_inspector()

def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device):
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)
if self._runtime_options.enable_zero_stage3_support:
from ._zero_stage3_compatibility import (
STAGE3_PULL_WEIGHT_TRIGGER_NAME,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
)

new_kwargs = {
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
}

# Then the trigger input will be the first user input.
return {
**new_kwargs,
**kwargs,
}
return kwargs

if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)

new_kwargs = {
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
}

# Then the trigger input will be the first user input.
return {
**new_kwargs,
**kwargs,
}

return kwargs
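
In the rewritten function, the memory-efficient path builds a one-entry dict holding the trigger tensor and merges the original kwargs after it, so the trigger becomes the first user input; the call sites now capture the returned dict (kwargs = self._append_pull_weight_trigger_as_input(...)). A minimal sketch of that merge, using a placeholder trigger name rather than the real per-mode constants:

import torch

def append_trigger_first(kwargs, device):
    # Sketch: put the trigger tensor in its own dict and merge the original
    # kwargs after it; dicts preserve insertion order, so the trigger ends up
    # as the first user input.
    trigger_name = "pull_weight_trigger"  # placeholder, not the real constant
    new_kwargs = {
        trigger_name: torch.zeros(1, dtype=torch.float32, device=device).requires_grad_()
    }
    return {**new_kwargs, **kwargs}

merged = append_trigger_first({"input_ids": torch.ones(2, 3)}, torch.device("cpu"))
print(list(merged.keys()))  # ['pull_weight_trigger', 'input_ids']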

def _log_feature_stats(self):
if get_rank() != 0:
@@ -163,10 +163,20 @@ def forward(self, *inputs, **kwargs):
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
self._append_pull_weight_trigger_as_input(kwargs, self._device)
kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters()
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
self._graph_initializers,
param_to_append_as_onnx_graph_inputs,
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
12 changes: 6 additions & 6 deletions orttraining/orttraining/python/training/ortmodule/_io.py
@@ -260,12 +260,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""):
)

# params is a list of all initializers known to the onnx graph
# if zero_stage3_offload_param_map:
# for p in params:
# if p not in zero_stage3_offload_param_map.values():
# result.append(p)
# else:
# result.extend(params)
if zero_stage3_offload_param_map:
for p in params:
if p not in zero_stage3_offload_param_map.values():
result.append(p)
else:
result.extend(params)

if rt_inspector.memory_ob.is_enabled() and not rt_inspector.memory_ob.symbolic_dim_collecting_completed:
rt_inspector.memory_ob.collect_symbolic_dim_values(input_info.dynamic_axes, onnx_input_to_value_map)
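
This hunk restores the ZeRO stage 3 filter that had been commented out: when a zero_stage3_offload_param_map is present, parameters found in that map are skipped and only the remaining ones are appended to the result; otherwise all params are appended as before. A stand-alone sketch of the same filtering idea (the identity-based membership test is an assumption for illustration):

def filter_offloaded_params(params, zero_stage3_offload_param_map):
    # Skip parameters that ZeRO stage 3 offloads (they are pulled through the
    # weight trigger at runtime instead); keep everything else.
    result = []
    if zero_stage3_offload_param_map:
        offloaded = {id(p) for p in zero_stage3_offload_param_map.values()}
        result.extend(p for p in params if id(p) not in offloaded)
    else:
        result.extend(params)
    return result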
@@ -19,6 +19,14 @@
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE = [1]


def get_params_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
return {k: v for k, v in named_params if v.requires_grad}


def get_params_not_connected_to_pull_param_trigger(named_params: Dict[str, torch.nn.parameter.Parameter]):
return [v for k, v in named_params if not v.requires_grad]


def post_processing_enable_mem_efficient_training(
exported_model: ModelProto,
named_params: Dict[str, torch.nn.parameter.Parameter],
@@ -29,7 +37,7 @@ def post_processing_enable_mem_efficient_training(
exported_model (ModelProto): The exported model.
named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map.
"""
trainable_named_params = {k: v for k, v in named_params if v.requires_grad}
trainable_named_params = get_params_connected_to_pull_param_trigger(named_params)

# Create weight retrieving function using trainable_named_params.
param_pull_trigger_func_class = _create_param_trigger_function(trainable_named_params)
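
The two helpers added at the top of this file split named_parameters() by requires_grad: parameters connected to the pull-weight trigger (trainable) come back as a name-to-parameter dict, while the rest come back as a plain list so they can still be appended as ONNX graph inputs by the callers shown earlier. A small usage sketch against a toy module, re-implementing the same comprehensions inline:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(2, 2))                  # trainable
        self.bias = torch.nn.Parameter(torch.zeros(2), requires_grad=False)  # frozen

model = Toy()

# Connected to the pull trigger: trainable params, keyed by name.
trainable = {k: v for k, v in model.named_parameters() if v.requires_grad}
# Not connected: frozen params, still appended as ONNX graph inputs.
frozen = [v for k, v in model.named_parameters() if not v.requires_grad]

print(list(trainable))  # ['weight']
print(len(frozen))      # 1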
@@ -75,7 +83,8 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
)

graph_inputs_to_remove = []
for graph_input in reversed(exported_model.graph.input):
input_offset = 0
for graph_input in exported_model.graph.input:
if graph_input.name not in trainable_named_params:
continue

@@ -110,7 +119,8 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
training_mode=1,
safe_run_mode=0,
)
exported_model.graph.node.insert(0, new_node)
exported_model.graph.node.insert(input_offset, new_node)
input_offset += 1

# Delete exported_model.graph.input
for input_to_remove in graph_inputs_to_remove:
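
The input_offset counter added here inserts each generated weight-pull node right after the ones created before it, keeping them at the front of the graph in the order the graph inputs are visited. A minimal list-based sketch of why an increasing offset is used rather than always inserting at a fixed index (which would reverse the relative order of the inserted items):

# Sketch: repeatedly inserting at a fixed index reverses the inserted items'
# relative order, while inserting at an increasing offset preserves it.
existing = ["node_a", "node_b"]
new_nodes = ["pull_w1", "pull_w2", "pull_w3"]

fixed_index = existing.copy()
for n in new_nodes:
    fixed_index.insert(0, n)             # ends as ['pull_w3', 'pull_w2', 'pull_w1', 'node_a', 'node_b']

increasing_offset = existing.copy()
offset = 0
for n in new_nodes:
    increasing_offset.insert(offset, n)  # ends as ['pull_w1', 'pull_w2', 'pull_w3', 'node_a', 'node_b']
    offset += 1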
@@ -314,10 +314,20 @@ def forward(self, *inputs, **kwargs):
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
self._append_pull_weight_trigger_as_input(kwargs, self._device)
kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters()
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
self._graph_initializers,
param_to_append_as_onnx_graph_inputs,
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),