Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GRU Training and GRU Gradient Kernels #16929

Merged
merged 10 commits into from
Aug 10, 2023
Merged

Conversation

baijumeswani
Copy link
Contributor

This pull request

  • Introduces the GRUTraining and the GRUGrad operators in the com.microsoft domain.
  • Updates the deep cpu gru implementation so as to support the training mode (where an extra output is generated which is the zrh gate computations)
  • Adds a graph transformer that replaces the GRU node with the GRUTraining node. The GRUTraining node generates a new output (the zrh gate computations) which is needed for the gradient computation.
  • Adds the kernel implementation for the GRUGrad op.

Pending work: Make GRUGrad and LSTMGrad implementation kernels parallel across the batch axis.

@baijumeswani baijumeswani added the training issues related to ONNX Runtime training; typically submitted using template label Jul 31, 2023
Copy link
Contributor

@pengwa pengwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, few comments.

@baijumeswani baijumeswani merged commit 31cbd63 into main Aug 10, 2023
91 of 94 checks passed
@baijumeswani baijumeswani deleted the baijumeswani/gru-gradient branch August 10, 2023 04:24
@baijumeswani
Copy link
Contributor Author

Thanks so much for the review @pengwa!

kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
training issues related to ONNX Runtime training; typically submitted using template
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants