Add sharding constraint for the output tensor in the model #18536

Merged: 11 commits into keras-team:master on Oct 6, 2023

Conversation

@qlzh727 (Member) commented Oct 2, 2023:

This will be used for sharding intermediate states (e.g. activations).

Updates to the API:

  1. keras.distribute.relayout to relayout / set a sharding constraint on a tensor value.
  2. LayoutMap now supports using the layer name as the key for the output layout of that layer.

The unit tests have been updated for JAX as a demonstration.

This should fix #18521
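
For illustration, a minimal sketch of how an output-sharding constraint could be declared through LayoutMap (the mesh shape, layer name, and the exact key convention for layer outputs are assumptions here; the key format is discussed further below in the review):

import keras
from keras import distribution

# Hypothetical sketch only: names and key convention are assumptions,
# not taken verbatim from this PR.
devices = distribution.list_devices()
mesh = distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)

layout_map = distribution.LayoutMap(mesh)
# Variable sharding (pre-existing behavior): shard the kernel over "model".
layout_map["dense_1/kernel"] = (None, "model")
# Intermediate/output sharding (what this PR adds): key on the layer path.
layout_map["dense_1"] = ("batch", "model")

distribution.set_distribution(
    distribution.ModelParallel(mesh, layout_map, batch_dim_name="batch")
)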

@codecov-commenter commented Oct 2, 2023:

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (25e4fa6) 78.00% compared to head (2fc70fc) 78.03%.
Report is 3 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18536      +/-   ##
==========================================
+ Coverage   78.00%   78.03%   +0.02%     
==========================================
  Files         334      334              
  Lines       32351    32406      +55     
  Branches     6313     6322       +9     
==========================================
+ Hits        25237    25287      +50     
- Misses       5546     5548       +2     
- Partials     1568     1571       +3     
Flag Coverage Δ
keras 77.93% <93.93%> (+0.02%) ⬆️
keras-jax 63.41% <93.93%> (-0.01%) ⬇️
keras-numpy 57.49% <42.42%> (-0.04%) ⬇️
keras-tensorflow 63.38% <33.33%> (-0.05%) ⬇️
keras-torch 64.26% <33.33%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Files Coverage Δ
keras/backend/jax/core.py 89.13% <100.00%> (ø)
keras/backend/jax/trainer.py 95.51% <100.00%> (ø)
keras/distribution/distribution_lib.py 95.37% <100.00%> (+0.31%) ⬆️
keras/layers/layer.py 88.40% <100.00%> (+0.16%) ⬆️
keras/backend/jax/distribution_lib.py 88.57% <81.81%> (-4.29%) ⬇️

... and 1 file with indirect coverage changes


@fchollet (Member) left a comment:

Thanks for the PR! Two points of discussion to finalize the API.

@@ -43,6 +43,13 @@ def distribute_value(value, tensor_layout):
return jax.device_put(value, tensor_layout)


def relayout(value, tensor_layout):
Member:

Right now the distinction between "distribute_value" and "relayout" is not clear -- according to the docstrings, both are about setting a layout on a tensor. I wonder if we could use function names that make the difference clearer. When would you use one and when would you use the other?

Member Author:

Indeed. I think the major difference here is that one of them is supposed to work in a jitted function, and the other one works in eager mode. Currently I don't think we have a good way to detect which mode the user code is in and automatically choose the proper API for them.

I think the JAX with_sharding_constraint name gives a good indication that it applies the constraint to an intermediate tensor/state within the function. Maybe we should just align with that?

Member:

Q: what's the difference between with_sharding_constraint applied to the input tensor (i.e. the data array), vs calling device_put on the input tensor? Right now we use device_put for data distribution. Does with_sharding_constraint work for that?

Member Author:

So I had some chats in the JAX user group, and the conclusion is that with_sharding_constraint is designed to be used only in jitted functions. It might not work in a lot of cases outside of jax.jit.

I think we should rely on device_put in the pure eager context, for input data as well as variable initialization.
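
A minimal pure-JAX illustration of the distinction being discussed (the mesh and shapes here are made up for the example):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), ("model",))
sharding = NamedSharding(mesh, P("model"))

# Eager: device_put shards a concrete array immediately.
x = jax.device_put(jnp.ones((8, 16)), sharding)

@jax.jit
def f(t):
    t = t * 2.0
    # Traced: with_sharding_constraint tells the compiler how this
    # intermediate value should be sharded; it is meant for use inside jit.
    return jax.lax.with_sharding_constraint(t, sharding)

y = f(x)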

Member:

Let's do this, I think:

  • Have separate distribute_variable and distribute_tensor APIs. Apparently this may be needed for TF?
  • In JAX distribute_tensor, check if we're in a tracing context. If so, use sharding_constraint. If not, use device_put (see the sketch after this list).
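
A hedged sketch of that dispatch for the JAX backend (the jax_utils.is_in_jax_tracing_scope() helper referenced below is the internal check this PR ends up adding; the import path and exact shape are assumptions):

import jax

from keras.backend.jax import jax_utils  # internal helper; path assumed


def distribute_tensor(tensor, layout):
    """Sketch: apply `layout` to a tensor value in eager or jit contexts."""
    if jax_utils.is_in_jax_tracing_scope():
        # Inside jit tracing: record a sharding constraint for the compiler,
        # since data cannot be physically moved while tracing.
        return jax.lax.with_sharding_constraint(tensor, layout)
    # Eager: physically place/shard the concrete array.
    return jax.device_put(tensor, layout)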

Member Author:

Done. I have added a check in distribute_tensor for whether it's in the jitted context, and I think it might not be a cheap check if we have to do it very often. I also didn't find a proper way to do this kind of check (see google/jax#9241). We might want to check with the JAX team on this.

if isinstance(value, KerasTensor):
# keras tensor is only used for building functional model, and can't be
# used to alter layout/sharding.
return value
Member:

Should KerasTensors still have a layout attribute? In case we need to read it on the tensor? Or is that not useful.

Member Author:

Great question. The issue I hit when using KerasTensor with the JAX sharding API is that it always tries to convert the KerasTensor to a jax array, which results in an error. It might make sense to add a layout attribute only when KerasTensor is a subclass of jax array or tf.Tensor.

distribution = distribution_lib.distribution()
if distribution is not None:
current_layer_path = current_path()
layout = distribution.get_tensor_layout(current_layer_path)
@fchollet (Member) commented Oct 3, 2023:

So for setting the layout of the output of a Dense layer in a subclassed model, you'd do like layout_map["model/layers/dense_1"] = (...)?

  1. Is there a risk of confusing variable layouts and intermediate tensor layouts?
  2. Should we be more specific, e.g. layout_map["model/layers/dense_1/output"] = (...) ? This could also leave the door open for input if ever needed.
  3. Is the full path too much information? What about layout_map["dense_1/output"] = (...)? Is that confusing?

@qlzh727 (Member Author) commented Oct 3, 2023:

So for setting the layout of the output of a Dense layer in a subclassed model, you'd do like layout_map["model/layers/dense_1"] = (...)?

Correct, it's based on the path/name scope of the subclassed model.

  1. Is there a risk of confusing variable layouts and intermediate tensor layouts?

It does. And my original intent was actually the same as option 2.

  2. Should we be more specific, e.g. layout_map["model/layers/dense_1/output"] = (...)? This could also leave the door open for input if ever needed.

That will definitely make it more explicit. It also opens the option of mapping to any intermediate Keras operation within the layer body.

  3. Is the full path too much information? What about layout_map["dense_1/output"] = (...)? Is that confusing?

Matt had the same question, and he proposed using regex.search instead of regex.match so that users can skip the prefix. My original implementation was trying to be a bit strict, so that a layout won't accidentally map to unwanted weights. In the case of overlapping rules that apply to the same weights, currently the first one wins. Maybe we can take the regex.search approach and raise an error when multiple rules map to the same weights/tensor. (I will probably do this in a separate PR.)

Member:

Maybe we can take the regex.search approach and raise an error when multiple rules map to the same weights/tensor

I think that's a good idea: we can use search and then make sure that each variable matches at most one rule.
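
A small sketch of that matching rule (illustrative only; the helper below is not the code in this PR):

import re


def resolve_layout(path, layout_rules):
    """Match `path` against layout rules with re.search, and fail loudly
    if more than one rule matches the same variable/tensor path."""
    matches = [key for key in layout_rules if re.search(key, path)]
    if len(matches) > 1:
        raise ValueError(
            f"Path '{path}' matches multiple layout rules: {matches}"
        )
    return layout_rules[matches[0]] if matches else None


# The short key still matches the full path, so users can skip the prefix.
rules = {"dense_1/output": ("batch", "model")}
assert resolve_layout("model/layers/dense_1/output", rules) == ("batch", "model")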

Member Author:

Ack, let me do this in a separate PR.

@qlzh727 (Member Author) commented Oct 3, 2023:

@mattdangerw

@qlzh727 qlzh727 requested a review from fchollet October 3, 2023 20:37
@qlzh727 (Member Author) commented Oct 5, 2023:

PTAL again.

@fchollet (Member) left a comment:

Thanks for the updates!

return jax.device_put(value, tensor_layout)

# TODO(scottzhu): This might not be a cheap check, we should consider
# have some proper JAX API for doing this check.
Member:

I consulted mattjj and that is not something they are considering.

Member Author:

Ack. Thanks for the confirmation.

# have some proper JAX API for doing this check.
if jax_utils.is_in_jax_tracing_scope():
return jax.lax.with_sharding_constraint(tensor, tensor_layout)
else:
Member:

You can remove the else block and just do a plain return.
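
A sketch of the suggested shape, using the variable names from the diff above (an assumption, not the final code):

if jax_utils.is_in_jax_tracing_scope():
    return jax.lax.with_sharding_constraint(tensor, tensor_layout)
return jax.device_put(tensor, tensor_layout)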

Member Author:

Done.

@@ -27,8 +29,33 @@ def list_devices(device_type=None):
return [f"{device.device_kind}:{device.id}" for device in jax_devices]


def distribute_value(value, tensor_layout):
"""Distribute the value based on the layout.
def distribute_variable(value, tensor_layout):
Member:

it's a bit weird to have distribute_variable(value, tensor_layout) if we also have distribute_tensor(value, tensor_layout). I suggest switching to distribute_variable(value, layout) and distribute_tensor(value, layout)

Member Author:

Done.

def distribute_variable(value, tensor_layout):
"""Create a distributed variable for JAX.

Since JAX doesn't have variable class, this will just return a jax.Array
Member:

a variable class

Member Author:

Done.

def distribute_variable(value, tensor_layout):
"""Create a distributed variable for JAX.

Since JAX doesn't have variable class, this will just return a jax.Array
Member:

Use backticks for code keywords.

Member Author:

Done.

"""Change the layout of a Tensor value in the jit function execution.

Note that this might not work outside of the jitted function for certain
backend. To change the layout of a value eagerly, please use
Member:

It should work in both situations in JAX, right?

Member Author:

It works for JAX, but might not for all the other backends, e.g. tf.dtensor.relayout() only works with DTensor instances and inside a tf.function, and the PyTorch behavior is unknown for now.

I think we will update the docstring once we have all the backends implemented, to reflect the final behavior.

@qlzh727 qlzh727 requested a review from fchollet October 6, 2023 17:20
@fchollet (Member) left a comment:

LGTM -- Thank you!

@google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels on Oct 6, 2023
@fchollet fchollet merged commit c57e454 into keras-team:master Oct 6, 2023
7 checks passed
@google-ml-butler bot removed the awaiting review, ready to pull (Ready to be merged into the codebase), and kokoro:force-run labels on Oct 6, 2023
@qlzh727 qlzh727 deleted the sharding_constraint branch October 10, 2023 23:35
Status: Merged

Successfully merging this pull request may close these issues:

[Distribution] Support sharding for intermediate tensor within the model

4 participants