ORTModule memory improvement #18924

Merged
merged 17 commits on Jan 16, 2024
fix again
pengwa committed Dec 28, 2023
commit 28f7c9eb423d69f9e94d988be08cf7a0c48c1c48
@@ -147,6 +147,10 @@ def __init__(

configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)

# Will be reset every time we re-initialize the graph builder.
# Note that we never enable this feature for inference mode.
self._mem_efficient_grad_management_is_enabled = False

def _get_torch_gpu_allocator_function_addresses(self):
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
# CPP extension to get torch GPU allocator's alloc and free function addresses
@@ -497,16 +501,22 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfiguration:
def _initialize_graph_builder(self):
"""Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""

self._mem_efficient_grad_management_is_enabled = (
self._export_mode != torch.onnx.TrainingMode.EVAL
and self._runtime_options.enable_mem_efficient_grad_management
)

# We post-process the exported model because the trainable parameters might have changed, so this path is
# re-triggered by _reinitialize_graph_builder.
exported_model = copy.deepcopy(self._onnx_models.exported_model)
self._onnx_models.processed_exported_model = exported_model
if self._runtime_options.enable_mem_efficient_grad_management:

if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training

# Override the flag if the model is not modified.
(
self._runtime_options.enable_mem_efficient_grad_management,
self._mem_efficient_grad_management_is_enabled,
exported_model,
) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())
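
(Aside: a minimal, self-contained sketch of the pattern this hunk applies, using simplified stand-in names rather than the actual ORTModule classes. The effective flag is derived once per graph-builder initialization, never for EVAL/inference export, and the post-processing helper may switch it back off when the exported model did not need modification.)

from dataclasses import dataclass

import torch


@dataclass
class RuntimeOptionsSketch:
    # Stand-in for the real runtime options object; only the relevant field is modeled.
    enable_mem_efficient_grad_management: bool = True


class GraphExecutionManagerSketch:
    def __init__(self, export_mode, runtime_options: RuntimeOptionsSketch):
        self._export_mode = export_mode
        self._runtime_options = runtime_options
        # Reset every time the graph builder is (re-)initialized; never enabled for inference.
        self._mem_efficient_grad_management_is_enabled = False

    def _initialize_graph_builder(self, exported_model, named_parameters):
        # Only meaningful when a gradient graph is built, i.e. not in EVAL export mode.
        self._mem_efficient_grad_management_is_enabled = (
            self._export_mode != torch.onnx.TrainingMode.EVAL
            and self._runtime_options.enable_mem_efficient_grad_management
        )
        if self._mem_efficient_grad_management_is_enabled:
            # Hypothetical stand-in for post_processing_enable_mem_efficient_training; it may
            # report False when no parameter input actually had to be rewritten.
            (
                self._mem_efficient_grad_management_is_enabled,
                exported_model,
            ) = self._post_process(exported_model, named_parameters)
        return exported_model

    def _post_process(self, exported_model, named_parameters):
        # Placeholder body; the real helper rewrites trainable-parameter inputs so they are
        # pulled on demand through a trigger input.
        return True, exported_model

Later checks read the cached flag rather than the raw runtime option, so a downgrade decided here propagates consistently to the rest of the run.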

@@ -543,7 +553,7 @@ def _initialize_graph_builder(self):
# Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)

if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
@@ -635,14 +645,11 @@ def _enable_conditional_optimizations(
inputs, kwargs
)

if (
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled:
self._append_pull_weight_trigger_as_input(kwargs, detected_device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
@@ -697,7 +704,7 @@ def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device):
device=device,
).requires_grad_()

if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
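(Aside: a rough sketch of the trigger-append step used when either ZeRO stage-3 support or mem-efficient grad management is enabled. The constant names appear in the diff above; their values, the function signature, and the body here are assumptions for illustration, not the library's implementation.)

from typing import Dict

import torch

MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME = "mem_efficient_pull_weight_trigger"  # assumed value
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE = torch.float32  # assumed dtype
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE = [1]  # assumed shape


def append_pull_weight_trigger_as_input(kwargs: Dict, device: torch.device) -> Dict:
    # Add a tiny requires-grad tensor; the graph "pulls" weights on demand through it.
    kwargs[MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME] = torch.zeros(
        MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
        dtype=MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
        device=device,
    ).requires_grad_()
    return kwargs


Usage (device assumed for the example):
    kwargs = append_pull_weight_trigger_as_input({}, torch.device("cpu"))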
@@ -310,14 +310,11 @@ def forward(self, *inputs, **kwargs):

self._gradient_accumulation_manager.maybe_update_cache_before_run()

if (
self._runtime_options.enable_zero_stage3_support
or self._runtime_options.enable_mem_efficient_grad_management
):
if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled:
self._append_pull_weight_trigger_as_input(kwargs, self._device)

param_to_append_as_onnx_graph_inputs = []
if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
@@ -506,7 +503,7 @@ def _reinitialize_graph_builder(self, input_info: _InputInfo):
if param.requires_grad and name in self._graph_initializer_names
}

if self._runtime_options.enable_mem_efficient_grad_management:
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Remove the inputs we added during model post-processing.
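(Aside: the second file mirrors the same flag substitution in forward() and in _reinitialize_graph_builder, which also strips the inputs added during model post-processing before the graph is rebuilt. A small hedged sketch of that cleanup, with simplified stand-in names:)

MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME = "mem_efficient_pull_weight_trigger"  # assumed value


def strip_post_processing_inputs(input_names_require_grad, mem_efficient_enabled):
    # Drop the synthetic trigger input that was appended while the feature was enabled, so
    # the recomputed require-grad name set does not carry a stale entry.
    if not mem_efficient_enabled:
        return input_names_require_grad
    return [name for name in input_names_require_grad if name != MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME]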