Add sharding constraint for the output tensor in the model #18536

Merged
merged 11 commits on Oct 6, 2023
Add unit test for sharding constraint
qlzh727 committed Oct 2, 2023
commit 90256817c6f4b4f7a17b08ad80ddc685302bb916
15 changes: 10 additions & 5 deletions keras/distribution/distribution_lib.py
@@ -14,6 +14,7 @@
import numpy as np

from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend import distribution_lib
from keras.backend.common import global_state

@@ -211,7 +212,7 @@ def get_tensor_layout(self, path):

Args:
path: a string path for the corresponding tensor.

Returns:
The `TensorLayout` for the intermediate tensor, which can be used
by `backend.relayout()` to reshard the tensor. Could also return
@@ -531,18 +532,22 @@ def _maybe_populate_device_mesh(self, layout):
@keras_export("keras.distribution.relayout")
def relayout(value, tensor_layout):
"""Change the layout of a Tensor value in the jit function execution.
Note that this will only work within the jitted function. To change the
layout of a value eagerly, please use

Note that this might not work outside of the jitted function for certain
backends. To change the layout of a value eagerly, please use
Member
It should work in both situations in JAX, right?

Member Author
It works for JAX, but might not for all the other backends, e.g. tf.dtensor.relayout() will only work with a DTensor instance and also within a tf.function, and PyTorch is also unknown for now.

I think we will update the docstring once we have all the backends implemented, to reflect the final behavior.

`backend.distribution_lib.distribute_value`.

Args:
value: a Tensor to change the layout.
tensor_layout: TensorLayout to be applied on the value.

Returns:
a new value with the specified tensor layout.
"""
if isinstance(value, KerasTensor):
# KerasTensor is only used for building the functional model, and can't be
# used to alter layout/sharding.
return value
Member
Should KerasTensors still have a layout attribute? In case we need to read it on the tensor? Or is that not useful.

Member Author
Great question. The issue I hit when using KerasTensor with the JAX sharding API is that it will always try to convert the KerasTensor to a jax array, which results in an error. It might make sense to add a layout only when KerasTensor is a subclass of jax array or tf.Tensor.

return distribution_lib.relayout(value, tensor_layout)
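For context, a rough usage sketch of how `get_tensor_layout` and `relayout` are meant to compose under the JAX backend. The mesh shape, axis name, and tensor path below are hypothetical, and the snippet assumes only the public `keras.distribution` pieces exercised elsewhere in this diff (`DeviceMesh`, `LayoutMap`, `ModelParallel`, `list_devices`):

import jax
import numpy as np
from keras import distribution

# Hypothetical 1-D mesh over all local devices; the axis name is illustrative.
devices = distribution.list_devices()
mesh = distribution.DeviceMesh((len(devices),), ["batch"], devices)

layout_map = distribution.LayoutMap(mesh)
# Constrain a (hypothetical) intermediate tensor path to be batch-sharded only.
layout_map["/model/layer/tensor"] = ("batch", None)
dist = distribution.ModelParallel(mesh, layout_map, batch_dim_name="batch")

@jax.jit
def forward(x):
    # Look up the layout registered for the tensor path; this returns None
    # when no constraint was configured for that path.
    layout = dist.get_tensor_layout("/model/layer/tensor")
    if layout is not None:
        # Reshard the intermediate value inside the jitted function.
        x = distribution.relayout(x, layout)
    return x * 2.0

out = forward(np.ones((16, 8), dtype="float32"))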


123 changes: 123 additions & 0 deletions keras/distribution/distribution_lib_test.py
@@ -1,5 +1,6 @@
"""Test for distribution_lib.py."""

import functools
import os
from unittest import mock

@@ -192,6 +193,15 @@ def test_get_variable_layout(self):
self.assertIs(variable_layout.device_mesh, self.device_mesh)
self.assertEqual(variable_layout.axes, (None,))

def test_get_tensor_layout(self):
distribution = distribution_lib.DataParallel(
device_mesh=self.device_mesh
)

path = "path/to/tensor"
tensor_layout = distribution.get_tensor_layout(path)
self.assertIsNone(tensor_layout)


class ModelParallelDistributionTest(testing.TestCase):
def setUp(self):
@@ -239,6 +249,22 @@ def test_distribute_data(self):
self.assertIs(data_layout.device_mesh, self.device_mesh)
self.assertEqual(data_layout.axes, ("data", None, None))

def test_get_tensor_layout(self):
layout_map = distribution_lib.LayoutMap(self.device_mesh)
layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"])
layout_map[".*bias"] = distribution_lib.TensorLayout(["model"])
layout_map["/model/layer/tensor"] = ("data", None)

distribution = distribution_lib.ModelParallel(
self.device_mesh, layout_map, batch_dim_name="data"
)
layout = distribution.get_tensor_layout("/model/layer/tensor")
self.assertIs(layout.device_mesh, self.device_mesh)
self.assertEqual(layout.axes, ("data", None))

layout = distribution.get_tensor_layout("/model/layer/other_tensor")
self.assertIsNone(layout)


class LayoutMapTest(testing.TestCase):
def setUp(self):
@@ -362,6 +388,31 @@ def test_list_devices(self):
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)

def test_relayout(self):
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)

inputs = jax.numpy.array(np.random.normal(size=(16, 8)))
original_layout = inputs.sharding
target_layout = jax.sharding.NamedSharding(
jax_mesh, jax.sharding.PartitionSpec("batch", None)
)

@functools.partial(jax.jit, static_argnames="target_layout")
def test_function(inputs, target_layout):
return distribution_lib.relayout(inputs, target_layout)

result = test_function(inputs, target_layout)
# Note that the returned tensor uses a different sharding implementation
# (GSPMDSharding), but it should be equivalent to the specified target
# layout.
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

# Test without jit
result = distribution_lib.relayout(inputs, target_layout)
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_to_jax_mesh(self):
devices = [f"cpu:{i}" for i in range(8)]
shape = (4, 2)
@@ -499,6 +550,78 @@ def test_e2e_model_parallel_model(self):
model.compile(loss="mse")
model.fit(inputs, labels)

def test_e2e_model_parallel_with_output_sharding(self):
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
shape, axis_names, backend_dlib.list_devices()
)

layout_map = distribution_lib.LayoutMap(device_mesh)
layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
[None, "model"]
)
layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])
# Force the dense layer output to be batch parallel only, and not
# sharded on the model dimension.
layout_map[".*dense"] = ("batch", None)

distribution = distribution_lib.ModelParallel(
device_mesh, layout_map, batch_dim_name="batch"
)
sharding_capture = ShardingCaptureLayer()
with distribution.scope():
inputs = layers.Input(shape=[28, 28, 1])
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = sharding_capture(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = models.Model(inputs=inputs, outputs=y)

for weight in model.weights:
if "kernel" in weight.name:
self.assertEqual(weight._value.sharding.spec, (None, "model"))
elif "bias" in weight.name:
self.assertEqual(weight._value.sharding.spec, ("model",))
else:
self.assertTrue(weight._value.sharding.is_fully_replicated)

inputs = np.random.normal(size=(32, 28, 28, 1))
labels = np.random.normal(size=(32, 10))

with distribution.scope():
model.compile(loss="mse")
model.fit(inputs, labels)

# Note that the intermediate_tensor_layout is only captured during the
# actual training, and not at model building time.
intermediate_tensor_layout = jax.sharding.NamedSharding(
backend_dlib._to_jax_mesh(distribution.device_mesh),
jax.sharding.PartitionSpec("batch", None),
)
self.assertTrue(
sharding_capture.captured_input_sharding.is_equivalent_to(
intermediate_tensor_layout, ndim=2
)
)
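For intuition, the behavior this test checks (forcing the dense output back to a batch-parallel layout rather than letting it inherit the model sharding) corresponds, in plain JAX terms, to a sharding constraint placed on the intermediate value inside the jitted step. A simplified, hypothetical sketch of that idea, not the Keras backend implementation, assuming `jax.lax.with_sharding_constraint` as the underlying primitive and an 8-device mesh like the one used above:

import jax
import numpy as np

# 2 x 4 mesh with ("batch", "model") axes, mirroring the test above.
mesh = jax.sharding.Mesh(
    np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
# Batch-parallel layout for the dense output: sharded on "batch" only.
output_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch", None)
)

@jax.jit
def dense_forward(x, kernel):
    # Without a constraint, y would typically inherit a "model"-sharded layout
    # from the kernel; the constraint below forces it back to batch-parallel.
    y = x @ kernel
    return jax.lax.with_sharding_constraint(y, output_sharding)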


class ShardingCaptureLayer(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.captured_input_sharding = None
self.supports_masking = True

def call(self, inputs):
jax.debug.inspect_array_sharding(
inputs, callback=lambda x: self.capture_input_sharding(x)
)
return inputs

def capture_input_sharding(self, sharding):
self.captured_input_sharding = sharding


# @pytest.mark.skipif(
# backend.backend() != "tensorflow",