Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORTModule memory improvement #18924

Merged
merged 17 commits into from
Jan 16, 2024
Prev Previous commit
Next Next commit
fix ci
  • Loading branch information
pengwa committed Dec 26, 2023
commit b3fdea9ecc23290d8617b1734a438b1d5e908545
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,7 @@ def _initialize_graph_builder(self):
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Add the memory-efficient grad trigger input name to input_names_require_grad, so that it is included in the gradient graph.

input_names_require_grad.insert(0, MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)
grad_builder_config.input_names_require_grad = input_names_require_grad
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
Expand Down Expand Up @@ -615,7 +614,7 @@ def _enable_conditional_optimizations(
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
kwargs = self._append_pull_weight_trigger_as_input(kwargs, detected_device)
self._append_pull_weight_trigger_as_input(kwargs, detected_device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
Expand Down Expand Up @@ -673,30 +672,18 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.devic
device=device,
).requires_grad_()

return kwargs

if self._runtime_options.enable_mem_efficient_grad_management:
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)

new_kwargs = {
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME: torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
}

# Then the trigger input will be the first user input.
return {
**new_kwargs,
**kwargs,
}

return kwargs
kwargs[MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME] = torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()

def _log_feature_stats(self):
if get_rank() != 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def forward(self, *inputs, **kwargs):
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device)
self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ def _get_param_pull_trigger_name(param_name: str) -> str:
exported_model.graph.input.remove(input_to_remove)

# Re-order the graph inputs so that the weight pull trigger is placed among the user inputs, immediately before the first trainable parameter.
exported_model.graph.input.insert(0, inputs[0])
offset = 0 # Find the first trainable param, insert the new input before it, as part of user inputs.
for input in exported_model.graph.input:
if input.name in named_params:
break
offset += 1
exported_model.graph.input.insert(offset, inputs[0])
exported_model.graph.node.insert(0, weight_pull_node)

return exported_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def forward(self, *inputs, **kwargs):
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
kwargs = self._append_pull_weight_trigger_as_input(kwargs, self._device)
self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
Expand Down