
ORTModule memory improvement #18924

Merged
merged 17 commits
Jan 16, 2024
10 changes: 10 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
@@ -293,6 +293,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_MEMORY_OPT_LEVEL=0
```

### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: Memory-efficient gradient management is turned off by default. When it is enabled, each gradient computed in ONNX Runtime triggers the corresponding parameter's backward function through the `PythonOpGrad` operator as soon as it is ready. This releases the gradient buffer managed by ONNX Runtime early; otherwise that buffer is only released once all backward computation finishes.

```bash
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 # Enable
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```
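
For example, a minimal sketch of enabling the feature in a training script (the model, optimizer, and tensor shapes below are illustrative placeholders, not part of this change):

```python
# Illustrative sketch only. Set the environment variable before ORTModule builds its graph.
import os

os.environ["ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"] = "1"

import torch
from onnxruntime.training.ortmodule import ORTModule

# Placeholder model; any trainable torch.nn.Module can be wrapped the same way.
model = ORTModule(torch.nn.Linear(1024, 1024).cuda())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()
loss.backward()  # each parameter's gradient is handed back to PyTorch as it is produced,
                 # so ONNX Runtime can release its gradient buffer before backward finishes
optimizer.step()
```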

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/execution_frame.cc
@@ -223,7 +223,8 @@ void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<con
const std::unordered_map<int, OrtValue>& initializers,
const std::function<bool(const std::string& name)>& is_initializer_sparse_func,
gsl::span<const OrtValue> fetches) {
ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size());
ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size(), "Get feed size: ", feeds.size(), " but expected feed size: ",
feed_mlvalue_idxs.size());
ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size());

// Need this for sparse conversions in host memory
9 changes: 6 additions & 3 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -2415,9 +2415,9 @@ def _infer_RotaryEmbedding(self, node): # noqa: N802

def _infer_PythonOp(self, node): # noqa: N802
output_tensor_types = get_attribute(node, "output_tensor_types")
assert output_tensor_types
assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
assert output_tensor_ranks
assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."

from onnxruntime.capi._pybind_state import get_shape_inference_function

@@ -2438,7 +2438,10 @@ def _infer_PythonOp(self, node): # noqa: N802
input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
input_dtypes.append(input_dtype)
output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1)
assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
f"but expected {len(node.output) - 1} outputs."
)
for i in range(len(node.output) - 1):
output_index = i + 1
vi = self.known_vi_[node.output[output_index]]
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -36,7 +36,6 @@
from ._io import _FlattenedModule, _InputInfo
from ._runtime_inspector import RuntimeInspector
from ._utils import check_function_has_param, get_rank
from ._zero_stage3_compatibility import stage3_export_context
from .options import DebugOptions, LogLevel, _MemoryOptimizationLevel, _RuntimeOptions
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension

@@ -148,6 +147,10 @@ def __init__(

configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)

# Will be reset every time we re-initialize the graph builder.
# Note that this feature is never enabled for inference mode.
self._mem_efficient_grad_management_is_enabled = False

def _get_torch_gpu_allocator_function_addresses(self):
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
# CPP extension to get torch GPU allocator's alloc and free function addresses
@@ -388,6 +391,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu
assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager"

try:
from ._zero_stage3_compatibility import stage3_export_context

with torch.no_grad(), stage3_export_context(self._runtime_options.enable_zero_stage3_support, self):
required_export_kwargs = {
"input_names": self._input_info.names,
@@ -496,9 +501,35 @@ def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfigurati
def _initialize_graph_builder(self):
"""Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""

self._mem_efficient_grad_management_is_enabled = (
self._export_mode != torch.onnx.TrainingMode.EVAL
and self._runtime_options.enable_mem_efficient_grad_management
)

# We post-process the exported model because the trainable parameters might have changed, so this path is
# re-triggered by reinitialize_graph_builder.
exported_model = copy.deepcopy(self._onnx_models.exported_model)
self._onnx_models.processed_exported_model = exported_model

if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import post_processing_enable_mem_efficient_training

# The helper may override the flag (disable the feature) if the model is not modified.
(
self._mem_efficient_grad_management_is_enabled,
exported_model,
) = post_processing_enable_mem_efficient_training(exported_model, self._flattened_module.named_parameters())

if self._runtime_options.run_symbolic_shape_infer:
exported_model = SymbolicShapeInference.infer_shapes(
exported_model, auto_merge=True, guess_output_rank=True
)

# All initializer names along with user inputs are a part of the onnx graph inputs
# since the onnx model was exported with the flag keep_initializers_as_inputs=True
onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}
# We need to use the raw exported model here since the graph inputs include both user inputs and
# parameters.
onnx_initializer_names = {p.name for p in exported_model.graph.input}

# TODO: PyTorch exporter bug: changes the initializer order in ONNX model
initializer_names = [
@@ -521,6 +552,13 @@ def _initialize_graph_builder(self):

# Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME)

if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME

# Add mem efficient grad trigger name to require_grad_names, so that it will be included in the gradient graph.
input_names_require_grad.append(MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME)

grad_builder_config.input_names_require_grad = input_names_require_grad
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
@@ -532,12 +570,23 @@ def _initialize_graph_builder(self):

# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
# and are kept as they appear in the exported onnx model.
self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
self._graph_builder.initialize(exported_model.SerializeToString(), grad_builder_config)

raw_onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}

raw_initializer_names = [
name for name, _ in self._flattened_module.named_parameters() if name in raw_onnx_initializer_names
]
raw_initializer_names_to_train = [
name
for name, param in self._flattened_module.named_parameters()
if param.requires_grad and name in raw_onnx_initializer_names
]

# TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train
# a set (unordered_set in the backend) that does not require a copy on each reference.
self._graph_initializer_names = set(initializer_names)
self._graph_initializer_names_to_train = set(initializer_names_to_train)
self._graph_initializer_names = set(raw_initializer_names)
self._graph_initializer_names_to_train = set(raw_initializer_names_to_train)

# Initializers can be cached and used since they are expected not to be re-instantiated
# between forward calls.
@@ -588,19 +637,29 @@ def _enable_conditional_optimizations(
# Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density.
if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density:
self._runtime_inspector.enable_input_inspector(
self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names
self._onnx_models.processed_exported_model, self._graph_builder.get_graph_info().user_input_names
)

if self._runtime_options.enable_sparse_optimizer:
detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(
inputs, kwargs
)

if self._runtime_options.enable_zero_stage3_support:
if self._runtime_options.enable_zero_stage3_support or self._mem_efficient_grad_management_is_enabled:
self._append_pull_weight_trigger_as_input(kwargs, detected_device)

param_to_append_as_onnx_graph_inputs = []
if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import get_params_not_connected_to_pull_param_trigger

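# Only parameters that are not pulled in through the parameter trigger still need to be fed as ONNX graph inputs.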
param_to_append_as_onnx_graph_inputs = get_params_not_connected_to_pull_param_trigger(
self._flattened_module.named_parameters(), self._onnx_models.exported_model
)
else:
param_to_append_as_onnx_graph_inputs = self._graph_initializers

_, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers(
self._graph_initializers,
param_to_append_as_onnx_graph_inputs,
self._graph_builder.get_graph_info().user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
@@ -632,19 +691,31 @@ def _enable_conditional_optimizations(
self._runtime_inspector.disable_input_inspector()

def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device):
from ._zero_stage3_compatibility import (
STAGE3_PULL_WEIGHT_TRIGGER_NAME,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
)
if self._runtime_options.enable_zero_stage3_support:
from ._zero_stage3_compatibility import (
STAGE3_PULL_WEIGHT_TRIGGER_NAME,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
)

kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()

if self._mem_efficient_grad_management_is_enabled:
from ._mem_efficient_grad_mgmt import (
MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE,
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
)

return kwargs
kwargs[MEM_EFFICIENT_PARAM_TRIGGER_INPUT_NAME] = torch.zeros(
MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch_dtype(MEM_EFFICIENT_PARAM_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()

def _log_feature_stats(self):
if get_rank() != 0: