
Add sharding constraint for the output tensor in the model #18536

Merged: 11 commits, Oct 6, 2023
keras/backend/jax/core.py (2 changes: 1 addition & 1 deletion)

@@ -35,7 +35,7 @@ def _initialize(self, value):

    def _direct_assign(self, value):
        if getattr(self, "_layout", None) is not None:
-            value = distribution_lib.distribute_value(value, self._layout)
+            value = distribution_lib.distribute_variable(value, self._layout)
        self._value = value

    def _convert_to_tensor(self, value, dtype=None):
keras/backend/jax/distribution_lib.py (39 changes: 36 additions & 3 deletions)

@@ -8,6 +8,8 @@
import jax
import numpy as np

+from keras.utils import jax_utils


def list_devices(device_type=None):
    """Return all the available devices based on the device type.

@@ -27,8 +29,33 @@ def list_devices(device_type=None):
    return [f"{device.device_kind}:{device.id}" for device in jax_devices]


-def distribute_value(value, tensor_layout):
-    """Distribute the value based on the layout.
+def distribute_variable(value, tensor_layout):

> Member: It's a bit weird to have `distribute_variable(value, tensor_layout)` if we also have `distribute_tensor(value, tensor_layout)`. I suggest switching to `distribute_variable(value, layout)` and `distribute_tensor(value, layout)`.
> Member (author): Done.

    """Create a distributed variable for JAX.

    Since JAX doesn't have a variable class, this will just return a
    `jax.Array` with the corresponding layout/sharding specified.

> Member: "a variable class"
> Member (author): Done.
> Member: Use backticks for code keywords.
> Member (author): Done.

    Note that this function should be used in an eager context, not in a
    jitted function.

    Args:
        value: the initial value of the variable.
        tensor_layout: `TensorLayout` for the created variable, or a
            `jax.sharding.Sharding` instance.

    Returns:
        `jax.Array` which is the distributed variable.
    """
    if not isinstance(tensor_layout, jax.sharding.Sharding):
        tensor_layout = _to_jax_layout(tensor_layout)
    return jax.device_put(value, tensor_layout)


def distribute_tensor(tensor, tensor_layout):
    """Distribute the tensor based on the layout.

    Note that this function can be used both in an eager context and within a
    jitted function.

    Args:
        tensor: `jax.Array` that needs to be distributed.

@@ -40,7 +67,13 @@ def distribute_value(value, tensor_layout):
    """
    if not isinstance(tensor_layout, jax.sharding.Sharding):
        tensor_layout = _to_jax_layout(tensor_layout)
-    return jax.device_put(value, tensor_layout)

    # TODO(scottzhu): This might not be a cheap check; we should consider
    # having a proper JAX API for doing this check.

> Member: I consulted mattjj and that is not something they are considering.
> Member (author): Ack. Thanks for the confirmation.

    if jax_utils.is_in_jax_tracing_scope():
        return jax.lax.with_sharding_constraint(tensor, tensor_layout)
    else:

> Member: You can remove the indent block and just do `return`.
> Member (author): Done.

        return jax.device_put(tensor, tensor_layout)


def _to_jax_device(device_id):
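For reference, the two placement paths the helpers above rely on can be exercised with raw JAX directly. The sketch below is illustrative only: it does not use the Keras helpers, the mesh and axis names are made up, and it assumes a JAX runtime whose device count divides the batch size.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 1D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ("batch",))
layout = NamedSharding(mesh, PartitionSpec("batch"))

# Batch size chosen to be divisible by the device count.
x = jnp.ones((len(jax.devices()) * 2, 4))

# Eager path: physically place/reshard the value, which is what the helpers
# above do outside of tracing via `jax.device_put`.
x_eager = jax.device_put(x, layout)

# Traced path: inside `jit`, placement becomes a sharding *constraint* that
# the compiler honors, which is what `distribute_tensor` does when it
# detects a tracing scope.
@jax.jit
def f(t):
    return jax.lax.with_sharding_constraint(t, layout) * 2

y = f(x)
print(x_eager.sharding, y.sharding)
```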
keras/backend/jax/trainer.py (2 changes: 1 addition & 1 deletion)

@@ -846,7 +846,7 @@ def _distribute_data(self, data):

            def distribute_single_value(d):
                layout = distribution.get_data_layout(d.shape)
-                return jax_distribution_lib.distribute_value(d, layout)
+                return jax_distribution_lib.distribute_tensor(d, layout)

            return jax.tree_util.tree_map(distribute_single_value, data)
        else:
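The trainer change above only swaps the per-leaf placement helper; the underlying pattern is a `tree_map` over the input structure with a per-leaf layout. A standalone sketch of that pattern, with an invented `get_data_layout` stand-in and illustrative shapes:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("batch",))
num_devices = len(jax.devices())

def get_data_layout(shape):
    # Stand-in for `distribution.get_data_layout()`: shard dim 0 over the
    # batch axis, replicate the remaining dims.
    return NamedSharding(
        mesh, PartitionSpec("batch", *([None] * (len(shape) - 1)))
    )

def distribute_single_value(d):
    return jax.device_put(d, get_data_layout(d.shape))

batch = {
    "images": np.ones((num_devices * 4, 28, 28, 1), dtype="float32"),
    "labels": np.ones((num_devices * 4, 10), dtype="float32"),
}
distributed_batch = jax.tree_util.tree_map(distribute_single_value, batch)
```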
keras/distribution/distribution_lib.py (44 additions)

@@ -14,6 +14,7 @@
import numpy as np

from keras.api_export import keras_export
+from keras.backend import KerasTensor
from keras.backend import distribution_lib
from keras.backend.common import global_state

@@ -170,6 +171,7 @@ class Distribution:

    1. Distribute the model variables to a `DeviceMesh`.
    2. Distribute the input data to a `DeviceMesh`.
+   3. Distribute an intermediate state tensor in the model.

    It can create a context scope so that the framework can properly detect
    the `Distribution` and distribute the variable/data accordingly.
@@ -205,6 +207,19 @@ def get_variable_layout(self, variable):
"""
raise NotImplementedError()

def get_tensor_layout(self, path):
"""Retrieve the `TensorLayout` for the intermediate tensor.

Args:
path: a string path for the correspoding tensor.

return:
The `TensorLayout` for the intermediate tensor, which can be used
by `backend.relayout()` to reshard the tensor. Could also return
None.
"""
raise NotImplementedError()

    @contextlib.contextmanager
    def scope(self):
        """Context manager to make the `Distribution` current."""
@@ -296,6 +311,10 @@ def get_variable_layout(self, variable):
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)

    def get_tensor_layout(self, path):
        # For data parallel training, the intermediate state is not changed.
        return None


@keras_export("keras.distribution.ModelParallel")
class ModelParallel(Distribution):
@@ -393,6 +412,9 @@ def get_variable_layout(self, variable):
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)

    def get_tensor_layout(self, path):
        return self._layout_map[path]


@keras_export("keras.distribution.LayoutMap")
class LayoutMap(collections.abc.MutableMapping):
@@ -507,6 +529,28 @@ def _maybe_populate_device_mesh(self, layout):
LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__


@keras_export("keras.distribution.distribute_tensor")
def distribute_tensor(tensor, tensor_layout):
    """Change the layout of a Tensor value in the jit function execution.

    Note that this might not work outside of the jitted function for certain
    backends. To change the layout of a value eagerly, please use
    `backend.distribution_lib.distribute_value`.

> Member: It should work in both situations in JAX, right?
> Member (author): It works for JAX, but might not for all the other backends, e.g. `tf.dtensor.relayout()` will only work with a dtensor instance and only within a `tf.function`, and pytorch is also unknown for now. I think we will update the docstring once we have all the backends implemented, to reflect the final behavior.

    Args:
        tensor: a Tensor to change the layout.
        tensor_layout: `TensorLayout` to be applied on the value.

    Returns:
        A new value with the specified tensor layout.
    """
    if isinstance(tensor, KerasTensor):
        # Keras tensors are only used for building functional models, and
        # can't be used to alter layout/sharding.
        return tensor
    return distribution_lib.distribute_tensor(tensor, tensor_layout)


@keras_export("keras.distribution.distribution")
def distribution():
    """Retrieve the current distribution from global context."""
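As a usage note for the exported `keras.distribution.distribute_tensor` above: it is a pass-through for symbolic `KerasTensor`s and delegates to the backend for concrete values. A minimal sketch, assuming the `keras.distribution` exports declared by the `keras_export` decorators in this diff and a one-axis mesh; the names and shapes are illustrative:

```python
import numpy as np
from keras import distribution, layers

# One-axis mesh over whatever devices are available.
devices = distribution.list_devices()
mesh = distribution.DeviceMesh((len(devices),), ["batch"], devices)
layout = distribution.TensorLayout(["batch", None], mesh)

# Symbolic KerasTensor (functional-model building): returned unchanged.
symbolic = layers.Input(shape=(4,))
assert distribution.distribute_tensor(symbolic, layout) is symbolic

# Concrete value: dispatched to the backend implementation (on JAX, either
# `device_put` or `with_sharding_constraint`, depending on tracing scope).
concrete = np.ones((len(devices) * 2, 4), dtype="float32")
sharded = distribution.distribute_tensor(concrete, layout)
```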
keras/distribution/distribution_lib_test.py (122 additions)
@@ -1,5 +1,6 @@
"""Test for distribution_lib.py."""

+import functools
import os
from unittest import mock

@@ -192,6 +193,15 @@ def test_get_variable_layout(self):
        self.assertIs(variable_layout.device_mesh, self.device_mesh)
        self.assertEqual(variable_layout.axes, (None,))

    def test_get_tensor_layout(self):
        distribution = distribution_lib.DataParallel(
            device_mesh=self.device_mesh
        )

        path = "path/to/tensor"
        tensor_layout = distribution.get_tensor_layout(path)
        self.assertIsNone(tensor_layout)


class ModelParallelDistributionTest(testing.TestCase):
    def setUp(self):

@@ -239,6 +249,22 @@ def test_distribute_data(self):
        self.assertIs(data_layout.device_mesh, self.device_mesh)
        self.assertEqual(data_layout.axes, ("data", None, None))

    def test_get_tensor_layout(self):
        layout_map = distribution_lib.LayoutMap(self.device_mesh)
        layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"])
        layout_map[".*bias"] = distribution_lib.TensorLayout(["model"])
        layout_map["/model/layer/tensor"] = ("data", None)

        distribution = distribution_lib.ModelParallel(
            self.device_mesh, layout_map, batch_dim_name="data"
        )
        layout = distribution.get_tensor_layout("/model/layer/tensor")
        self.assertIs(layout.device_mesh, self.device_mesh)
        self.assertEqual(layout.axes, ("data", None))

        layout = distribution.get_tensor_layout("/model/layer/other_tensor")
        self.assertIsNone(layout)


class LayoutMapTest(testing.TestCase):
    def setUp(self):

@@ -362,6 +388,30 @@ def test_list_devices(self):
        self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
        self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)

    def test_distribute_tensor(self):
        jax_mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(2, 4), ("batch", "model")
        )

        inputs = jax.numpy.array(np.random.normal(size=(16, 8)))
        target_layout = jax.sharding.NamedSharding(
            jax_mesh, jax.sharding.PartitionSpec("batch", None)
        )

        @functools.partial(jax.jit, static_argnames="target_layout")
        def test_function(inputs, target_layout):
            return distribution_lib.distribute_tensor(inputs, target_layout)

        result = test_function(inputs, target_layout)
        # Note that the returned tensor has a different sharding
        # implementation (GSPMDSharding), but it should be equivalent to the
        # target layout specified.
        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

        # Test without jit
        result = distribution_lib.distribute_tensor(inputs, target_layout)
        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

    def test_to_jax_mesh(self):
        devices = [f"cpu:{i}" for i in range(8)]
        shape = (4, 2)

@@ -499,6 +549,78 @@ def test_e2e_model_parallel_model(self):
        model.compile(loss="mse")
        model.fit(inputs, labels)

    def test_e2e_model_parallel_with_output_sharding(self):
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
            shape, axis_names, backend_dlib.list_devices()
        )

        layout_map = distribution_lib.LayoutMap(device_mesh)
        layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
            [None, "model"]
        )
        layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])
        # Force the dense layer output to be batch parallel only, and not
        # sharded on the model dimension.
        layout_map[".*dense.*output"] = ("batch", None)

        distribution = distribution_lib.ModelParallel(
            device_mesh, layout_map, batch_dim_name="batch"
        )
        sharding_capture = ShardingCaptureLayer()
        with distribution.scope():
            inputs = layers.Input(shape=[28, 28, 1])
            y = layers.Flatten()(inputs)
            y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
            y = sharding_capture(y)
            y = layers.Dropout(0.4)(y)
            y = layers.Dense(units=10, activation="softmax")(y)
            model = models.Model(inputs=inputs, outputs=y)

        for weight in model.weights:
            if "kernel" in weight.name:
                self.assertEqual(weight._value.sharding.spec, (None, "model"))
            elif "bias" in weight.name:
                self.assertEqual(weight._value.sharding.spec, ("model",))
            else:
                self.assertTrue(weight._value.sharding.is_fully_replicated)

        inputs = np.random.normal(size=(32, 28, 28, 1))
        labels = np.random.normal(size=(32, 10))

        with distribution.scope():
            model.compile(loss="mse")
            model.fit(inputs, labels)

        # Note that the intermediate tensor layout is only captured during
        # the actual training, and not at model building time.
        intermediate_tensor_layout = jax.sharding.NamedSharding(
            backend_dlib._to_jax_mesh(distribution.device_mesh),
            jax.sharding.PartitionSpec("batch", None),
        )
        self.assertTrue(
            sharding_capture.captured_input_sharding.is_equivalent_to(
                intermediate_tensor_layout, ndim=2
            )
        )


class ShardingCaptureLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.captured_input_sharding = None
        self.supports_masking = True

    def call(self, inputs):
        jax.debug.inspect_array_sharding(
            inputs, callback=lambda x: self.capture_input_sharding(x)
        )
        return inputs

    def capture_input_sharding(self, sharding):
        self.captured_input_sharding = sharding

# @pytest.mark.skipif(
# backend.backend() != "tensorflow",
keras/layers/layer.py (15 additions)

@@ -30,6 +30,8 @@
from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend.common import global_state
+from keras.backend.common.name_scope import current_path
+from keras.distribution import distribution_lib
from keras.layers import input_spec
from keras.metrics.metric import Metric
from keras.ops.operation import Operation
@@ -808,6 +810,19 @@ def maybe_convert(x):
            outputs = super().__call__(*args, **kwargs)
        else:
            outputs = super().__call__(*args, **kwargs)

        # Change the layout of the layer output if needed.
        # This is useful for re-laying out an intermediate tensor in the
        # model to achieve optimal performance.
        distribution = distribution_lib.distribution()
        if distribution is not None:
            current_layer_path = current_path()
            current_layer_path += "/output"
            layout = distribution.get_tensor_layout(current_layer_path)
            if layout:
                outputs = distribution_lib.distribute_tensor(
                    outputs, layout
                )

        if not self.built:
            self.built = True
        # Record activity regularizer loss.

> fchollet (Member, Oct 3, 2023): So for setting the layout of the output of a Dense layer in a subclassed model, you'd do something like layout_map["model/layers/dense_1"] = (...)?
>
> 1. Is there a risk of confusing variable layouts and intermediate tensor layouts?
> 2. Should we be more specific, e.g. layout_map["model/layers/dense_1/output"] = (...)? This could also leave the door open for "input" if ever needed.
> 3. Is the full path too much information? What about layout_map["dense_1/output"] = (...)? Is that confusing?
>
> qlzh727 (Member, author, Oct 3, 2023): Correct, it's based on the path/name scope of the subclassed model.
>
> 1. It does, and my original intent was actually the same as option 2.
> 2. That would definitely make it more explicit, and it also opens the option of mapping to any intermediate Keras operation within the layer body.
> 3. Matt had the same question; he proposed using regex.search instead of regex.match so that users can skip the prefix. My original implementation was trying to be a bit strict, so that a layout won't accidentally map to unwanted weights. When overlapping rules apply to the same weights, currently the first one wins. Maybe we can take the regex.match approach and raise an error when multiple rules map to the same weights/tensor. (Probably I will do this in a separate PR.)
>
> fchollet: I think that's a good idea: we can use search and then make sure that each variable matches at most 1 rule.
>
> qlzh727: Ack, let me do this in a separate PR.
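To make the path convention discussed above concrete, here is a hedged end-to-end sketch of how a user would pin a layer's output layout with an ".*output" key, modeled on the `test_e2e_model_parallel_with_output_sharding` test in this PR. The 2x4 mesh (8 devices), layer sizes, and regex keys are illustrative assumptions, not part of the diff.

```python
import numpy as np
from keras import distribution, layers, models

# Assumes 8 available devices for the 2x4 mesh.
device_mesh = distribution.DeviceMesh(
    (2, 4), ["batch", "model"], distribution.list_devices()
)

layout_map = distribution.LayoutMap(device_mesh)
layout_map[".*dense.*kernel"] = distribution.TensorLayout([None, "model"])
layout_map[".*dense.*bias"] = distribution.TensorLayout(["model"])
# New in this PR: keep dense outputs sharded on batch only. Layer.__call__
# looks up "<layer path>/output" and applies distribute_tensor on a match.
layout_map[".*dense.*output"] = ("batch", None)

model_parallel = distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)

with model_parallel.scope():
    inputs = layers.Input(shape=[28, 28, 1])
    y = layers.Flatten()(inputs)
    y = layers.Dense(200, activation="relu")(y)
    y = layers.Dense(10, activation="softmax")(y)
    model = models.Model(inputs=inputs, outputs=y)

    model.compile(loss="mse")
    model.fit(
        np.random.normal(size=(32, 28, 28, 1)),
        np.random.normal(size=(32, 10)),
    )
```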