Add sharding constraint for the output tensor in the model #18536
Changes from 9 commits
@@ -8,6 +8,8 @@
import jax
import numpy as np

from keras.utils import jax_utils


def list_devices(device_type=None):
    """Return all the available devices based on the device type.

@@ -27,8 +29,33 @@ def list_devices(device_type=None):
    return [f"{device.device_kind}:{device.id}" for device in jax_devices]


def distribute_value(value, tensor_layout):
    """Distribute the value based on the layout.
def distribute_variable(value, tensor_layout):
    """Create a distributed variable for JAX.

    Since JAX doesn't have variable class, this will just return a jax.Array

[Review thread on the docstring line above]
Reviewer: "a variable class"
Author: Done.
Reviewer: Use backticks for code keywords.
Author: Done.

    with the corresponding layout/sharding specified.

    Note that this function should be used in eager context, not in jitted
    function.

    Args:
        value: the initial value of the variable.
        tensor_layout: `TensorLayout` for the created variable, or a
            `jax.sharding.Sharding` instance.

    Returns:
        jax.Array which is the distributed variable.
    """
    if not isinstance(tensor_layout, jax.sharding.Sharding):
        tensor_layout = _to_jax_layout(tensor_layout)
    return jax.device_put(value, tensor_layout)


def distribute_tensor(tensor, tensor_layout):
    """Distribute the tensor based on the layout.

    Note that this function can be used both in eager context, or within a
    jitted function.

    Args:
        value: `jax.Array` that need to be distributed.

@@ -40,7 +67,13 @@ def distribute_value(value, tensor_layout):
    """
    if not isinstance(tensor_layout, jax.sharding.Sharding):
        tensor_layout = _to_jax_layout(tensor_layout)
    return jax.device_put(value, tensor_layout)

    # TODO(scottzhu): This might not be a cheap check, we should consider
    # have some proper JAX API for doing this check.

[Review thread on the TODO above]
Reviewer: I consulted mattjj and that is not something they are considering.
Author: Ack. Thanks for the confirmation.

    if jax_utils.is_in_jax_tracing_scope():
        return jax.lax.with_sharding_constraint(tensor, tensor_layout)
    else:

[Review thread on the else branch]
Reviewer: You can remove the indent block and just do
Author: Done.

        return jax.device_put(tensor, tensor_layout)


def _to_jax_device(device_id):
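For context, here is a minimal usage sketch of the two backend helpers above. It is not part of the diff; the mesh shape, array sizes, and module path are illustrative and assume the JAX backend with the source layout used in this PR.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical import; module path follows this PR's source layout.
from keras.backend.jax import distribution_lib

# Build a 2D mesh over whatever devices are visible.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("batch", "model"))
sharding = NamedSharding(mesh, PartitionSpec("batch", None))

# Eager path: both helpers accept a jax.sharding.Sharding (or a TensorLayout)
# and place the value with jax.device_put.
value = np.ones((16, 8), dtype="float32")
variable = distribution_lib.distribute_variable(value, sharding)

# Traced path: inside jit, distribute_tensor emits a sharding constraint
# instead of a device_put.
@jax.jit
def fn(x):
    return distribution_lib.distribute_tensor(x, sharding)

out = fn(variable)
```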
@@ -14,6 +14,7 @@
import numpy as np

from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend import distribution_lib
from keras.backend.common import global_state


@@ -170,6 +171,7 @@ class Distribution:

    1. Distribute the model variables to a `DeviceMesh`.
    2. Distribute the input data to a `DeviceMesh`.
    3. Distribute an intermediate state tensor in the model.

    It can create a context scope so that the framework to properly detect the
    `Distribution` and distribute the variable/data accordingly.
@@ -205,6 +207,19 @@ def get_variable_layout(self, variable):
        """
        raise NotImplementedError()

    def get_tensor_layout(self, path):
        """Retrieve the `TensorLayout` for the intermediate tensor.

        Args:
            path: a string path for the corresponding tensor.

        Returns:
            The `TensorLayout` for the intermediate tensor, which can be used
            by `backend.relayout()` to reshard the tensor. Could also return
            None.
        """
        raise NotImplementedError()

    @contextlib.contextmanager
    def scope(self):
        """Context manager to make the `Distribution` current."""
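To make the new hook concrete, here is a hypothetical `Distribution` subclass (not from the PR) that resolves intermediate-tensor layouts from a plain dict keyed by path. It assumes the base class stores the mesh passed to its constructor and that `get_data_layout`/`get_variable_layout` are the other abstract methods, as in this file.

```python
class DictLayoutDistribution(Distribution):
    """Toy distribution with an explicit path -> TensorLayout mapping (sketch only)."""

    def __init__(self, device_mesh, tensor_layouts):
        super().__init__(device_mesh)
        # e.g. {"my_model/dense_1/output": TensorLayout(["batch", "model"], mesh)}
        self._tensor_layouts = tensor_layouts

    def get_data_layout(self, data_shape):
        # Replicate the input data in this toy example.
        return TensorLayout([None] * len(data_shape), self.device_mesh)

    def get_variable_layout(self, variable):
        # Replicate every variable in this toy example.
        return TensorLayout([None] * len(variable.shape), self.device_mesh)

    def get_tensor_layout(self, path):
        # Returning None means "leave the intermediate tensor's layout alone".
        return self._tensor_layouts.get(path)
```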
@@ -296,6 +311,10 @@ def get_variable_layout(self, variable):
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)

    def get_tensor_layout(self, path):
        # For data parallel training, the intermediate state is not changed.
        return None


@keras_export("keras.distribution.ModelParallel")
class ModelParallel(Distribution):
@@ -393,6 +412,9 @@ def get_variable_layout(self, variable):
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)

    def get_tensor_layout(self, path):
        return self._layout_map[path]


@keras_export("keras.distribution.LayoutMap")
class LayoutMap(collections.abc.MutableMapping):
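For `ModelParallel`, the lookup above simply defers to the `LayoutMap`, so intermediate-tensor rules can live next to the variable rules. A hedged sketch follows; the device count, axis names, and path strings are made up, and keys are regexes matched against variable/tensor paths.

```python
import keras

devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)

layout_map = keras.distribution.LayoutMap(mesh)
# Existing behavior: shard the kernel of a Dense layer across the model axis.
layout_map["dense_1/kernel"] = keras.distribution.TensorLayout([None, "model"], mesh)
# New in this PR: a key ending in "/output" targets that layer's output tensor.
layout_map["dense_1/output"] = keras.distribution.TensorLayout(["batch", "model"], mesh)

mp = keras.distribution.ModelParallel(mesh, layout_map, batch_dim_name="batch")
print(mp.get_tensor_layout("dense_1/output"))  # the TensorLayout set above
# DataParallel.get_tensor_layout(), by contrast, always returns None.
```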
@@ -507,6 +529,28 @@ def _maybe_populate_device_mesh(self, layout):
LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__


@keras_export("keras.distribution.distribute_tensor")
def distribute_tensor(tensor, tensor_layout):
    """Change the layout of a Tensor value in the jit function execution.

    Note that this might not work outside of the jitted function for certain
    backend. To change the layout of a value eagerly, please use
    `backend.distribution_lib.distribute_value`.

[Review thread on the docstring note above]
Reviewer: It should work in both situations in JAX, right?
Author: It works for JAX, but might not for all the other backends, e.g.
`tf.dtensor.relayout()` will only work with a DTensor instance and only inside
a `tf.function`, and PyTorch is also unknown for now. I think we will update
the docstring once we have all the backends implemented, to reflect the final
behavior.

    Args:
        tensor: a Tensor to change the layout.
        tensor_layout: TensorLayout to be applied on the value.

    Returns:
        a new value with the specified tensor layout.
    """
    if isinstance(tensor, KerasTensor):
        # keras tensor is only used for building functional model, and can't be
        # used to alter layout/sharding.
        return tensor
    return distribution_lib.distribute_tensor(tensor, tensor_layout)


@keras_export("keras.distribution.distribution")
def distribution():
    """Retrieve the current distribution from global context."""
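A short usage sketch for the newly exported helper; shapes and mesh are hypothetical, and the eager behavior outside jit depends on the backend, as the thread above notes.

```python
import numpy as np
import keras

devices = keras.distribution.list_devices()
# 1D mesh so the batch dimension is split across all visible devices.
mesh = keras.distribution.DeviceMesh((len(devices),), ("batch",), devices)
layout = keras.distribution.TensorLayout(["batch", None], mesh)

# Batch size chosen so it divides evenly across the devices.
x = np.ones((2 * len(devices), 4), dtype="float32")

# On the JAX backend this is a device_put eagerly, or a sharding constraint
# when called inside a jitted function; KerasTensors pass through unchanged.
x_sharded = keras.distribution.distribute_tensor(x, layout)
```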
@@ -30,6 +30,8 @@
from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend.common import global_state
from keras.backend.common.name_scope import current_path
from keras.distribution import distribution_lib
from keras.layers import input_spec
from keras.metrics.metric import Metric
from keras.ops.operation import Operation

@@ -808,6 +810,19 @@ def maybe_convert(x):
                outputs = super().__call__(*args, **kwargs)
        else:
            outputs = super().__call__(*args, **kwargs)
        # Change the layout for the layer output if needed.
        # This is useful for relayout intermediate tensor in the model
        # to achieve the optimal performance.
        distribution = distribution_lib.distribution()
        if distribution is not None:
            current_layer_path = current_path()
            current_layer_path += "/output"
            layout = distribution.get_tensor_layout(current_layer_path)
[Review thread on the layout lookup above]
Reviewer: So for setting the layout of the output of a […]
Author: Correct, it's based on the path/name scope of the subclass model.
Author: It does. And my original intent was actually the same as option 2.
Author: That will definitely make it more explicit. It also opens the option of
mapping to any intermediate Keras operations within the layer body.
Author: Matt has the same question, and he was proposing to use regex.search
instead of regex.match so that the user can skip the prefix. My original
implementation was trying to be a bit strict, so that the layout won't
accidentally map to unwanted weights. In the case that there are overlapping
rules that apply to the same weights, currently the first one wins. Maybe we
can take the regex.match approach and raise an error when multiple rules map to
the same weights/tensor. (Probably I will do this in a separate PR.)
Reviewer: I think that's a good idea, we can use […]
Author: Ack, let me do this in a separate PR.
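To make the `regex.match` vs `regex.search` point in this thread concrete, a tiny standard-library sketch (the paths are hypothetical):

```python
import re

path = "my_model/dense_1/kernel"  # hypothetical variable path inside a model

# re.match anchors at the start of the string, so a rule must spell out (or
# wildcard) the model-name prefix before it can hit this path.
assert re.match("dense_1/kernel", path) is None
assert re.match(".*dense_1/kernel", path) is not None

# re.search matches anywhere, so users can skip the prefix entirely.
assert re.search("dense_1/kernel", path) is not None
```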
            if layout:
                outputs = distribution_lib.distribute_tensor(
                    outputs, layout
                )

        if not self.built:
            self.built = True
            # Record activity regularizer loss.
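Putting the pieces together, here is a hedged end-to-end sketch of how a user could drive this new code path. The layer/model names and mesh shape are illustrative; the resharding itself happens when the model executes, since during functional construction the outputs are `KerasTensor`s and pass through unchanged.

```python
import keras

devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh((1, len(devices)), ("batch", "model"), devices)

layout_map = keras.distribution.LayoutMap(mesh)
# Layer.__call__ looks up "<layer_path>/output" via current_path(), so a rule
# keyed on that path adds a sharding constraint to the layer's output.
layout_map["my_model/dense_1/output"] = keras.distribution.TensorLayout(
    ["batch", "model"], mesh
)

distribution = keras.distribution.ModelParallel(mesh, layout_map, batch_dim_name="batch")
with distribution.scope():
    inputs = keras.Input(shape=(16,))
    x = keras.layers.Dense(32, name="dense_1")(inputs)
    model = keras.Model(inputs, x, name="my_model")
```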
[Review thread on the new backend functions]
Reviewer: It's a bit weird to have `distribute_variable(value, tensor_layout)`
if we also have `distribute_tensor(value, tensor_layout)`. I suggest switching
to `distribute_variable(value, layout)` and `distribute_tensor(value, layout)`.
Author: Done.