ORTModule GraphTransitionManager #19007

Merged: pengwa merged 39 commits into main on Jul 3, 2024

ORTModule GraphTransitionManager #19007

merged 39 commits into from
Jul 3, 2024

Conversation


@pengwa pengwa commented Jan 4, 2024

Problem

Currently, the codebase contains logic for model re-export checks and graph_builder re-initialization checks. Ideally, these operations should behave like a state machine. Upon inspecting the implementation, however, it becomes apparent that the relevant states are checked or set in scattered locations, which makes it hard to understand when a re-export or re-initialization will be triggered. For clarity and maintainability, these states should be consolidated into one cohesive component rather than dispersed throughout the current graph execution manager.

Furthermore, model export and the post-export processing for stage 3 support or memory-efficient gradient management introduce considerable complexity. To improve the structure of the codebase, it would be beneficial to extract these functionalities into a dedicated component, separate from the current graph execution manager.

As part of this effort, it is also essential to address inconsistencies in how inputs and outputs are flattened and unflattened. Currently, several functions perform these operations recursively, each with a slightly different implementation, so different parts of the code support different input/output data types and structures. This pull request simplifies these operations into a set of primitive functions, ensuring uniformity. This streamlines the code and makes it easier to stay consistent when fixing bugs or adding support for new data types. One thing to mention here: input/output handling is deeply bound to the graph transition work described above, so it is difficult to make this change separately.
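
To illustrate what "a set of primitive functions" can look like, here is a minimal sketch; the function names and the schema representation are hypothetical and are not taken from this PR.

```python
# Minimal sketch (not the PR's actual code) of shared flatten/unflatten primitives:
# one recursive walk that every caller reuses, so all input/output paths support
# the same nested structures (tensors inside lists, tuples, and dicts).
from typing import Any

import torch


def flatten(data: Any) -> tuple[list[torch.Tensor], Any]:
    """Return (flat tensor list, schema) for a nested structure of tensors."""
    if isinstance(data, torch.Tensor):
        return [data], None  # leaf: no sub-schema needed
    if isinstance(data, (list, tuple)):
        flat, schemas = [], []
        for item in data:
            item_flat, item_schema = flatten(item)
            flat.extend(item_flat)
            schemas.append((item_schema, len(item_flat)))
        return flat, (type(data), schemas)
    if isinstance(data, dict):
        flat, schemas = [], []
        for key, value in data.items():
            value_flat, value_schema = flatten(value)
            flat.extend(value_flat)
            schemas.append((key, value_schema, len(value_flat)))
        return flat, (dict, schemas)
    raise TypeError(f"Unsupported input type: {type(data)}")


def unflatten(flat: list[torch.Tensor], schema: Any) -> Any:
    """Rebuild the original nested structure from a flat tensor list."""
    if schema is None:
        return flat[0]
    container, schemas = schema
    if container is dict:
        result, pos = {}, 0
        for key, sub_schema, count in schemas:
            result[key] = unflatten(flat[pos : pos + count], sub_schema)
            pos += count
        return result
    rebuilt, pos = [], 0
    for sub_schema, count in schemas:
        rebuilt.append(unflatten(flat[pos : pos + count], sub_schema))
        pos += count
    return container(rebuilt)
```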

While this logic is complex, the codebase benefits from an extensive suite of unit tests covering the affected branches. Getting all of these tests to pass has been a time-intensive but necessary part of this development effort.

Design

Introduce GraphTransitionManager and put all model export and post-export processing logic in it. It is responsible for:

  1. Re-export check
  2. Do export
  3. Re-post-export process check
  4. Do post-export process
  5. Return a PostExportProcessedModelInfo, which contains all the information needed to pass to ORT to build the gradient graph (currently we do the same for training and evaluation; ideally we should not do it for evaluation, but let's keep this behavior as it is now and change it later). Its fields are listed below, followed by an illustrative sketch of how they are used.
```python
        # Input names for the pre-gradient-build graph.
        # These may differ from the names in ExportedGraph, since we may modify the graph inputs as needed,
        # for example when memory-efficient gradient management is enabled.
        self.onnx_graph_input_names: list[str] = onnx_graph_input_names

        # A subset of onnx_graph_input_names: the input names that require gradients
        # for the pre-gradient-build graph.
        self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad

        # Symbolic names for each dimension of the graph inputs (onnx_graph_input_names).
        # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name},
        # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}}.
        self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map

        self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict()

        # The ONNX graph input names excluding parameters and buffers.
        self.onnx_graph_input_names_user_defined = onnx_graph_input_names_user_defined

        # The subset of user-defined graph input names that require gradients.
        self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined

        self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model

        # Functions to access the input data from the args and kwargs.
        # If not None, the length is the same as onnx_graph_input_names:
        # the i-th function retrieves the data for the i-th input name from args and kwargs.
        self.data_accessor: list[callable] | None = data_accessor

        # Used for unflattening the outputs from the ORT forward run.
        self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema
```

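To make the roles of `data_accessor` and `buffer_for_ort_runs` concrete, here is a hypothetical sketch of how `construct_inputs` could use them; the method body below is an assumption for illustration, not the implementation in this PR.

```python
# Hypothetical sketch -- the actual construct_inputs in this PR may differ.
# data_accessor[i] knows how to pull the tensor for onnx_graph_input_names[i]
# out of the user's args/kwargs; results are staged in buffer_for_ort_runs,
# the dict that is later handed to ORT for execution.
def construct_inputs(self, args, kwargs):
    for name, accessor in zip(self.onnx_graph_input_names, self.data_accessor):
        self.buffer_for_ort_runs[name] = accessor(args, kwargs)
    return self.buffer_for_ort_runs
```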
The GraphTransitionManager instance is a property of the GraphExecutionManager (e.g. `TrainingManager` or `InferenceManager`):

  1. Use `self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs)` to check whether the PyTorch module needs a re-export or a re-post-export process.
  2. Use `self._graph_transition_manager._post_export_processed_model_info.construct_inputs` to construct the list of inputs used for ORT runs.
  3. Use `self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)` to restore the outputs to the original PyTorch output structure. A sketch of this flow appears after the list.
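
For orientation, here is a rough sketch of how a GraphExecutionManager forward path could stitch these calls together. It is illustrative only: `run_forward_with_ort` is a placeholder, and the exact signatures may differ from the code in this PR.

```python
# Hypothetical sketch of the call flow described above -- not the actual
# TrainingManager/InferenceManager code in this PR.
def forward(self, *inputs, **kwargs):
    # 1. Re-export / re-run post-export processing only when the cached model is stale.
    self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs)
    model_info = self._graph_transition_manager._post_export_processed_model_info

    # 2. Flatten the user's args/kwargs into the ordered inputs ORT expects.
    ort_inputs = model_info.construct_inputs(inputs, kwargs)

    # 3. Run the ORT graph (placeholder for the actual execution-agent call).
    user_outputs = run_forward_with_ort(ort_inputs)

    # 4. Restore the original nested PyTorch output structure.
    return model_info.restore_outputs(user_outputs)
```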

Motivation and Context

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Jan 4, 2024
@pengwa pengwa changed the title Refactor ORTModule model export/process and input/output handling Introduce GraphTransitionManager for ORTModule Jan 4, 2024
@pengwa pengwa changed the title Introduce GraphTransitionManager for ORTModule ORTModule GraphTransitionManager Jan 4, 2024
@pengwa pengwa requested a review from mindest January 5, 2024 02:54
pengwa added a commit that referenced this pull request Jan 16, 2024
## Dependency

#19007

## ORTModule memory efficient gradient management

Previously I tried to solve the coarse-grained gradient
accumulation/update problem in ORTModule with
#8979, but that approach was not fully validated with DDP, or when user
hooks are registered on the gradient accumulation of torch parameters.

This PR addresses the problem with a similar approach to PR 8979,
i.e. triggering gradient accumulation once ORT has computed the grad, but
instead of using an AccumulateGrad op, it uses an ONNX PythonOp that
internally calls param.backward(grad), which handles all related hooks
correctly.
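
As a minimal illustration of this mechanism (not code from this PR): calling backward() on a leaf parameter with an explicit gradient goes through autograd's normal accumulation path, which is what lets hooks registered on the parameter run as usual. The snippet below only demonstrates the accumulation part.

```python
import torch

# Calling backward() directly on a leaf parameter with an explicit gradient
# accumulates into param.grad, just like a regular backward pass would.
param = torch.nn.Parameter(torch.zeros(3))

grad_from_ort = torch.ones(3)   # stand-in for a gradient computed by ORT
param.backward(grad_from_ort)   # accumulate into param.grad
param.backward(grad_from_ort)   # a second micro-step accumulates again

assert torch.equal(param.grad, torch.full((3,), 2.0))
```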


## Design

Check the details in:

https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

Differences are mostly in the 0.000x range, sometimes 0.00x, which may come from
the gradient apply happening in a different order before vs. after this change
(on DeepSpeed ZeRO stage 2).


## TODO

Consolidate the logic with Stage3's similar logic.
@pengwa pengwa (Contributor Author) left a comment

Sorry for the late response.

@pengwa pengwa requested a review from wschin March 26, 2024 01:41
@wschin (Contributor)

wschin commented Apr 3, 2024

construct_inputs and restore_outputs can probably be done by calling tree_flatten and tree_unflatten in https://github.com/pytorch/pytorch/blob/15bd81bfafa86fec9d675e7f071c867c852ebe8f/torch/utils/_pytree.py#L799.
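
For reference, those helpers live in PyTorch's (internal) torch.utils._pytree module; a minimal usage sketch:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Flatten an arbitrarily nested structure of tensors into a flat leaf list plus
# a spec describing the structure, then rebuild the original structure from it.
nested = {"a": torch.ones(2), "b": [torch.zeros(3), (torch.arange(4),)]}
leaves, spec = tree_flatten(nested)
restored = tree_unflatten(leaves, spec)

assert torch.equal(restored["b"][1][0], torch.arange(4))
```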

@wschin wschin (Contributor) left a comment

LGTM now. Thanks.

@pengwa pengwa (Contributor Author)

pengwa commented Jul 3, 2024

Thanks @wschin.

@pengwa pengwa merged commit 4932e04 into main Jul 3, 2024
96 checks passed
@pengwa pengwa deleted the pengwa/refactor_io branch July 3, 2024 02:53