Add sharding constraint for the output tensor in the model #18536
```diff
@@ -14,6 +14,7 @@
 import numpy as np

 from keras.api_export import keras_export
 from keras.backend import KerasTensor
+from keras.backend import distribution_lib
 from keras.backend.common import global_state

```
```diff
@@ -170,6 +171,7 @@ class Distribution:

     1. Distribute the model variables to a `DeviceMesh`.
     2. Distribute the input data to a `DeviceMesh`.
+    3. Distribute an intermediate state tensor in the model.

     It can create a context scope so that the framework can properly detect
     the `Distribution` and distribute the variables/data accordingly.
```
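The scoping behavior described in the docstring can be illustrated with a toy sketch. This is a hypothetical, simplified reproduction of the contract (a context manager installs the distribution as "current" and restores the previous one on exit); the class and helper names here are illustrative, not the actual Keras internals.

```python
import contextlib

# Module-level slot standing in for Keras's global state (illustrative).
_CURRENT_DISTRIBUTION = None


class ToyDistribution:
    """Hypothetical stand-in for `keras.distribution.Distribution`."""

    def __init__(self, name):
        self.name = name

    @contextlib.contextmanager
    def scope(self):
        global _CURRENT_DISTRIBUTION
        previous = _CURRENT_DISTRIBUTION
        _CURRENT_DISTRIBUTION = self
        try:
            yield
        finally:
            # Restore whatever was current before entering the scope.
            _CURRENT_DISTRIBUTION = previous


def current_distribution():
    return _CURRENT_DISTRIBUTION


with ToyDistribution("data_parallel").scope():
    inside = current_distribution().name
outside = current_distribution()
```

Inside the `with` block the framework can look up the active distribution; after exit, the prior state (here, `None`) is restored.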
```diff
@@ -205,6 +207,19 @@ def get_variable_layout(self, variable):
        """
        raise NotImplementedError()

+    def get_tensor_layout(self, path):
+        """Retrieve the `TensorLayout` for the intermediate tensor.
+
+        Args:
+            path: a string path for the corresponding tensor.
+
+        Returns:
+            The `TensorLayout` for the intermediate tensor, which can be
+            used by `backend.relayout()` to reshard the tensor. Could also
+            return None.
+        """
+        raise NotImplementedError()
+
    @contextlib.contextmanager
    def scope(self):
        """Context manager to make the `Distribution` current."""
```
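A subclass implementing the new `get_tensor_layout` abstract method might look like the following minimal sketch: a dictionary lookup keyed by tensor path, returning `None` when no layout is registered. The `TensorLayoutStub` class and `DictBackedDistribution` name are assumptions for illustration; the real `TensorLayout` also carries a device mesh.

```python
class TensorLayoutStub:
    """Illustrative stand-in for `keras.distribution.TensorLayout`."""

    def __init__(self, axes):
        self.axes = axes


class DictBackedDistribution:
    """Hypothetical distribution that maps tensor paths to layouts."""

    def __init__(self, layout_map):
        self._layout_map = dict(layout_map)

    def get_tensor_layout(self, path):
        # Return the registered layout, or None so callers can skip
        # resharding for unconstrained tensors.
        return self._layout_map.get(path)


dist = DictBackedDistribution(
    {"model/dense_1/output": TensorLayoutStub(["batch", None])}
)
hit = dist.get_tensor_layout("model/dense_1/output")
miss = dist.get_tensor_layout("model/dense_2/output")
```

Returning `None` for unmatched paths is what lets the caller treat resharding as optional.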
```diff
@@ -296,6 +311,10 @@ def get_variable_layout(self, variable):
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)

+    def get_tensor_layout(self, path):
+        # For data parallel training, the intermediate state is not changed.
+        return None
+

@keras_export("keras.distribution.ModelParallel")
class ModelParallel(Distribution):
```
```diff
@@ -393,6 +412,9 @@ def get_variable_layout(self, variable):
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)

+    def get_tensor_layout(self, path):
+        return self._layout_map[path]
+

@keras_export("keras.distribution.LayoutMap")
class LayoutMap(collections.abc.MutableMapping):
```
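The review thread below discusses how `LayoutMap` keys are matched against variable/tensor paths (`regex.search` vs. `regex.match`, and first-match-wins for overlapping rules). The sketch below illustrates those semantics in plain Python; the patterns and layout values are made up for the example, and this is not the actual Keras lookup code.

```python
import re

# Overlapping rules: with first-match-wins, ordering matters, so the
# more specific pattern is listed first (illustrative values).
layout_map = {
    r".*dense.*kernel": "sharded_kernel_layout",
    r".*dense.*": "fallback_dense_layout",
}


def lookup(path):
    for pattern, layout in layout_map.items():
        # re.search lets users skip the model-name prefix; re.match
        # would require the pattern to cover the path from its start.
        if re.search(pattern, path):
            return layout
    return None


first = lookup("my_model/dense_1/kernel")
```

Under first-match-wins, `"my_model/dense_1/kernel"` resolves to the kernel rule even though the fallback rule also matches; a path matching no rule yields `None`.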
```diff
@@ -507,6 +529,28 @@ def _maybe_populate_device_mesh(self, layout):
 LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__


+@keras_export("keras.distribution.distribute_tensor")
+def distribute_tensor(tensor, layout):
+    """Change the layout of a Tensor value in the jit function execution.
+
+    Note that this might not work outside of the jitted function for
+    certain backends. To change the layout of a value eagerly, please use
+    `backend.distribution_lib.distribute_value`.
+
+    Args:
+        tensor: a Tensor to change the layout.
+        layout: `TensorLayout` to be applied on the value.
+
+    Returns:
+        a new value with the specified tensor layout.
+    """
+    if isinstance(tensor, KerasTensor):
+        # Keras tensors are only used for building functional models,
+        # and can't be used to alter layout/sharding.
+        return tensor
+    return distribution_lib.distribute_tensor(tensor, layout)
+
+
 @keras_export("keras.distribution.distribution")
 def distribution():
     """Retrieve the current distribution from global context."""
```

Review comment (on the docstring note about jitted execution): It should work in both situations in JAX, right?

Reply: It works for JAX, but might not for all the other backends; e.g. `tf.dtensor.relayout()` will only work with a DTensor instance and only inside a `tf.function`, and PyTorch is also unknown for now. I think we will update the docstring once we have all the backends implemented, to reflect the final behavior.
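The dispatch in `distribute_tensor` can be reproduced as a small standalone sketch: symbolic tensors (which carry no concrete value) pass through untouched, while concrete values are handed to the backend for resharding. `KerasTensorStub` and `backend_distribute` are stand-ins invented for this example; they are not the real Keras classes or backend functions.

```python
class KerasTensorStub:
    """Stand-in for a symbolic `KerasTensor` with no concrete value."""


def backend_distribute(tensor, layout):
    # Placeholder for the backend's distribute_tensor implementation;
    # here we just record that a reshard would have happened.
    return ("resharded", tensor, layout)


def distribute_tensor(tensor, layout):
    if isinstance(tensor, KerasTensorStub):
        # Symbolic tensors only build the functional graph, so there is
        # no data to move; return them unchanged.
        return tensor
    return backend_distribute(tensor, layout)


symbolic = KerasTensorStub()
passthrough = distribute_tensor(symbolic, "some_layout")
moved = distribute_tensor([1.0, 2.0], "some_layout")
```

The passthrough branch is why calling this during functional-model construction is safe: the symbolic value is returned by identity.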
```diff
@@ -30,6 +30,8 @@
 from keras.api_export import keras_export
 from keras.backend import KerasTensor
 from keras.backend.common import global_state
+from keras.backend.common.name_scope import current_path
+from keras.distribution import distribution_lib
 from keras.layers import input_spec
 from keras.metrics.metric import Metric
 from keras.ops.operation import Operation
```
```diff
@@ -808,6 +810,19 @@ def maybe_convert(x):
                outputs = super().__call__(*args, **kwargs)
            else:
                outputs = super().__call__(*args, **kwargs)
+            # Change the layout for the layer output if needed.
+            # This is useful for relayouting an intermediate tensor in
+            # the model to achieve optimal performance.
+            distribution = distribution_lib.distribution()
+            if distribution is not None:
+                current_layer_path = current_path()
+                current_layer_path += "/output"
+                layout = distribution.get_tensor_layout(current_layer_path)
+                if layout:
+                    outputs = distribution_lib.distribute_tensor(
+                        outputs, layout
+                    )
+
            if not self.built:
                self.built = True
                # Record activity regularizer loss.
```

Review comment: So for setting the layout of the output of a

Reply: Correct, it's based on the path/name scope of the subclass model.

Reply: It does. And my original intent was actually the same as option 2.

Reply: That will definitely make it more explicit. And it also opens the option of mapping to any intermediate Keras operations within the layer body.

Reply: Matt has the same question, and he was proposing to use `regex.search` instead of `regex.match` so that users can skip the prefix. My original implementation was trying to be a bit strict, so that the layout won't accidentally map to unwanted weights. In the case that overlapping rules apply to the same weights, currently the first one wins. Maybe we can take the `regex.match` approach and raise an error when multiple rules map to the same weights/tensor. (Probably I will do this in a separate PR.)

Reply: I think that's a good idea, we can use

Reply: Ack, let me do this in a separate PR.
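The hook above keys intermediate-tensor layouts by the layer's name-scope path plus an `/output` suffix. The toy reproduction below shows how such a path is built and used for the layout lookup; the scope strings and layout tuples are invented for illustration.

```python
def output_path(layer_scope):
    # e.g. "functional_model/dense_1" -> "functional_model/dense_1/output",
    # mirroring the `current_path() + "/output"` construction above.
    return layer_scope + "/output"


# Hypothetical mapping from output paths to layouts (axis names made up).
tensor_layouts = {
    "functional_model/dense_1/output": ("batch", "model"),
}


def layout_for_layer(layer_scope):
    # A missing entry yields None, so the caller skips the reshard,
    # matching the `if layout:` guard in the hook above.
    return tensor_layouts.get(output_path(layer_scope))


constrained = layout_for_layer("functional_model/dense_1")
unconstrained = layout_for_layer("functional_model/dense_2")
```

Only layers whose `/output` path appears in the map get a sharding constraint applied; all others flow through unchanged.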
Review comment: I consulted mattjj and that is not something they are considering.

Reply: Ack. Thanks for the confirmation.