Add sharding constraint for the output tensor in the model #18536
@@ -14,6 +14,7 @@
 import numpy as np

 from keras.api_export import keras_export
 from keras.backend import KerasTensor
 from keras.backend import distribution_lib
 from keras.backend.common import global_state
@@ -211,7 +212,7 @@ def get_tensor_layout(self, path):

     Args:
         path: a string path for the corresponding tensor.

     Returns:
         The `TensorLayout` for the intermediate tensor, which can be used
         by `backend.relayout()` to reshard the tensor. Could also return
@@ -531,18 +532,22 @@ def _maybe_populate_device_mesh(self, layout):
 @keras_export("keras.distribution.relayout")
 def relayout(value, tensor_layout):
     """Change the layout of a Tensor value in the jit function execution.

-    Note that this will only work within the jitted function. To change the
-    layout of a value eagerly, please use
+    Note that this might not work outside of the jitted function for certain
+    backends. To change the layout of a value eagerly, please use
     `backend.distribution_lib.distribute_value`.

     Args:
         value: a Tensor to change the layout.
         tensor_layout: TensorLayout to be applied on the value.

     Returns:
         a new value with the specified tensor layout.
     """
+    if isinstance(value, KerasTensor):
+        # A KerasTensor is only used for building functional models, and
+        # can't be used to alter layout/sharding.
+        return value
     return distribution_lib.relayout(value, tensor_layout)

Review comment (on the KerasTensor early return):

Reviewer: Should KerasTensors still have a …

Author: Great question. The issue I hit when using KerasTensor with the JAX sharding API is that it will always try to convert the KerasTensor to a JAX array, which results in an error. It might make sense to add a layout only when KerasTensor is a subclass of jax array or tf.Tensor.
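The early-return guard discussed above can be illustrated with a minimal, self-contained sketch. The class and function names here are stand-ins for the real Keras symbols (`KerasTensor`, `distribution_lib.relayout`), not the actual implementations: the point is only that symbolic tensors pass through unchanged while concrete values are handed to the backend.

```python
class KerasTensor:
    """Stand-in for Keras's symbolic tensor, which exists only while
    building a functional model and can't be converted to a backend array."""
    pass


def backend_relayout(value, tensor_layout):
    # Stand-in for `distribution_lib.relayout`. A real backend would apply
    # a sharding constraint here (e.g. JAX's
    # `jax.lax.with_sharding_constraint`).
    return ("relaid_out", value, tensor_layout)


def relayout(value, tensor_layout):
    if isinstance(value, KerasTensor):
        # Symbolic tensors would fail the backend's array conversion (the
        # error the PR author hit), so the layout request is a no-op
        # during model construction.
        return value
    return backend_relayout(value, tensor_layout)
```

This mirrors the structure of the diff: the guard makes `relayout` safe to call on both symbolic and concrete values.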
Review comment (on the docstring change):

Reviewer: It should work in both situations in JAX, right?

Author: It works for JAX, but might not for all the other backends; e.g. tf.dtensor.relayout() will only work with a DTensor instance and only inside a tf.function, and the PyTorch behavior is also unknown for now. I think we will update the docstring once we have all the backends implemented, to reflect the final behavior.
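The backend differences described in this exchange could be captured in a dispatch layer. The sketch below is purely illustrative (the parameter names and return values are hypothetical, not the Keras API): it encodes the stated behavior that JAX accepts the constraint in either context, tf.dtensor requires a traced (tf.function) context, and the torch path is still undetermined.

```python
def relayout(value, tensor_layout, backend="jax", in_traced_context=False):
    """Hypothetical per-backend dispatch for a relayout request.

    Encodes the constraints from the review discussion; the real Keras
    implementation lives in each backend's `distribution_lib`.
    """
    if backend == "jax":
        # JAX supports the constraint under jit (via
        # jax.lax.with_sharding_constraint) and placement eagerly
        # (via jax.device_put), so no context check is needed.
        return ("jax_relayout", value)
    if backend == "tensorflow":
        # tf.experimental.dtensor.relayout only accepts DTensor instances
        # and must run inside a tf.function.
        if not in_traced_context:
            raise NotImplementedError(
                "dtensor.relayout requires a tf.function context"
            )
        return ("dtensor_relayout", value)
    # The torch behavior is unknown at the time of this discussion.
    raise NotImplementedError(f"relayout not implemented for {backend}")
```

Once all backends are implemented, the docstring can state the unified behavior rather than hedging per backend, as the author suggests.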