Add sharding constraint for the output tensor in the model #18536
Conversation
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## master #18536 +/- ##
==========================================
+ Coverage 78.00% 78.03% +0.02%
==========================================
Files 334 334
Lines 32351 32406 +55
Branches 6313 6322 +9
==========================================
+ Hits 25237 25287 +50
- Misses 5546 5548 +2
- Partials 1568 1571 +3
Thanks for the PR! Two points of discussion to finalize the API.
@@ -43,6 +43,13 @@ def distribute_value(value, tensor_layout):
    return jax.device_put(value, tensor_layout)


def relayout(value, tensor_layout):
Right now the distinction between "distribute_value" and "relayout" is not clear -- according to the docstrings, both are about setting a layout on a tensor. I wonder if we could use function names that make the difference clearer. When would you use one and when would you use the other?
Indeed. I think the major difference is that one of them is supposed to work in a jitted function, and the other works in eager mode. Currently I don't think we have a good way to detect which mode user code is in and auto-choose the proper API for them.
I think the JAX `with_sharding_constraint` name gives a good indication that it applies the constraint to an intermediate tensor/state within the function. Maybe we should just align with that?
Q: what's the difference between `with_sharding_constraint` applied to the input tensor (i.e. the data array), vs calling `device_put` on the input tensor? Right now we use `device_put` for data distribution. Does `with_sharding_constraint` work for that?
So I had a chat in the JAX user group, and the conclusion is that `with_sharding_constraint` is designed to be used only in jitted functions; it might not work in a lot of cases outside of `jax.jit`. I think we should rely on `device_put` in the pure eager context, for input data as well as variable initialization.
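As a sketch of that split (illustrative only, not the Keras implementation; assumes a single-host mesh over whatever devices are available): `device_put` places data eagerly, while `with_sharding_constraint` constrains an intermediate inside a `jax.jit` trace.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over the available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Eager context: place input data / initial variable values directly.
x = jax.device_put(jnp.ones((8, 4)), sharding)

# Jitted context: constrain the layout of an intermediate tensor.
@jax.jit
def forward(x):
    h = x * 2.0
    return jax.lax.with_sharding_constraint(h, sharding)

y = forward(x)
```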
Let's do this, I think:

- Have separate `distribute_variable` and `distribute_tensor` APIs. Apparently this may be needed for TF?
- In the JAX `distribute_tensor`, check if we're in a tracing context. If so, use `sharding_constraint`. If not, use `device_put`.
Done. I have added a check in `distribute_tensor` for whether it's in a jitted context, and I think it might not be a cheap check if we have to do it very often. I also didn't find a proper way to do this kind of check (google/jax#9241). We might want to check with the JAX team about this.
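A minimal sketch of the idea (hypothetical; `is_in_jax_tracing_scope` and `distribute_tensor` here only mirror the approach discussed, not the actual Keras helpers): create a tiny array and see whether it comes back as an abstract `Tracer`.

```python
import jax
import jax.numpy as jnp


def is_in_jax_tracing_scope():
    # Inside a jax.jit trace, primitive calls are staged out, so array
    # creation returns a Tracer subclass instead of a concrete array.
    x = jnp.ones(())
    return any(c.__name__ == "Tracer" for c in type(x).__mro__)


def distribute_tensor(tensor, layout):
    # Hypothetical dispatch: sharding constraint while tracing,
    # device placement in eager mode.
    if is_in_jax_tracing_scope():
        return jax.lax.with_sharding_constraint(tensor, layout)
    return jax.device_put(tensor, layout)
```

Note the per-call array creation is exactly why this may not be a cheap check when called frequently.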
if isinstance(value, KerasTensor):
    # KerasTensor is only used for building the functional model, and
    # can't be used to alter layout/sharding.
    return value
Should `KerasTensor`s still have a `layout` attribute, in case we need to read it on the tensor? Or is that not useful?
Great question. The issue I hit when using KerasTensor with the JAX sharding API is that it always tries to convert the KerasTensor to a jax array, which results in an error. It might make sense to add a layout attribute only if KerasTensor were a subclass of `jax.Array` or `tf.Tensor`.
distribution = distribution_lib.distribution()
if distribution is not None:
    current_layer_path = current_path()
    layout = distribution.get_tensor_layout(current_layer_path)
So for setting the layout of the output of a `Dense` layer in a subclassed model, you'd do like `layout_map["model/layers/dense_1"] = (...)`?

- Is there a risk of confusing variable layouts and intermediate tensor layouts?
- Should we be more specific, e.g. `layout_map["model/layers/dense_1/output"] = (...)`? This could also leave the door open for `input` if ever needed.
- Is the full path too much information? What about `layout_map["dense_1/output"] = (...)`? Is that confusing?
> So for setting the layout of the output of a `Dense` layer in a subclassed model, you'd do like `layout_map["model/layers/dense_1"] = (...)`?

Correct, it's based on the path/name scope of the subclassed model.

> Is there a risk of confusing variable layouts and intermediate tensor layouts?

It does. My original intent was actually the same as option 2.

> Should we be more specific, e.g. `layout_map["model/layers/dense_1/output"] = (...)`? This could also leave the door open for `input` if ever needed.

That will definitely make it more explicit. It also opens the option of mapping to any intermediate Keras operation within the layer body.

> Is the full path too much information? What about `layout_map["dense_1/output"] = (...)`? Is that confusing?

Matt had the same question, and he proposed using `regex.search` instead of `regex.match` so that users can skip the prefix. My original implementation was trying to be a bit strict, so that a layout won't accidentally map to unwanted weights. In the case of overlapping rules that apply to the same weights, currently the first one wins. Maybe we can take the `regex.search` approach and raise an error when multiple rules map to the same weights/tensor. (Probably I will do this in a separate PR.)
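To illustrate that proposal (a hypothetical sketch of the rule-resolution change, not the current Keras behavior): match rules with `re.search` so the path prefix can be skipped, and raise when more than one rule applies to the same path.

```python
import re


def resolve_layout(path, layout_map):
    # Collect every rule whose pattern occurs anywhere in the path.
    matches = [rule for rule in layout_map if re.search(rule, path)]
    if len(matches) > 1:
        raise ValueError(
            f"Path '{path}' matched multiple layout rules: {matches}"
        )
    return layout_map[matches[0]] if matches else None


layout_map = {
    r"dense_1/output": ("batch", "model"),
    r"dense_2/kernel": ("model", None),
}

# re.search lets the user skip the "model/layers/" prefix:
layout = resolve_layout("model/layers/dense_1/output", layout_map)
```

Compared to the first-match-wins behavior, failing loudly on overlapping rules makes accidental mappings visible instead of silently picking one.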
> Maybe we can take the regex.search approach, and raise an error when multiple rules map to the same weights/tensor.

I think that's a good idea. We can use `search` and then make sure that each variable matches at most one rule.
Ack, let me do this in a separate PR.
PTAL again.
Thanks for the updates!
    return jax.device_put(value, tensor_layout)


# TODO(scottzhu): This might not be a cheap check, we should consider
# have some proper JAX API for doing this check.
I consulted mattjj and that is not something they are considering.
Ack. Thanks for the confirmation.
# have some proper JAX API for doing this check.
if jax_utils.is_in_jax_tracing_scope():
    return jax.lax.with_sharding_constraint(tensor, tensor_layout)
else:
You can remove the `else` block and just `return` directly.
Done.
@@ -27,8 +29,33 @@ def list_devices(device_type=None):
    return [f"{device.device_kind}:{device.id}" for device in jax_devices]


def distribute_value(value, tensor_layout):
    """Distribute the value based on the layout.
def distribute_variable(value, tensor_layout):
It's a bit weird to have `distribute_variable(value, tensor_layout)` if we also have `distribute_tensor(value, tensor_layout)`. I suggest switching to `distribute_variable(value, layout)` and `distribute_tensor(value, layout)`.
Done.
def distribute_variable(value, tensor_layout):
    """Create a distributed variable for JAX.

    Since JAX doesn't have variable class, this will just return a jax.Array
a variable class
Done.
def distribute_variable(value, tensor_layout):
    """Create a distributed variable for JAX.

    Since JAX doesn't have variable class, this will just return a jax.Array
Use backticks for code keywords.
Done.
"""Change the layout of a Tensor value in the jit function execution.

Note that this might not work outside of the jitted function for certain
backend. To change the layout of a value eagerly, please use
It should work in both situations in JAX, right?
It works for JAX, but might not for all the other backends; e.g. `tf.dtensor.relayout()` only works with DTensor instances and inside a `tf.function`, and PyTorch is also unknown for now.
I think we will update the docstring once we have all the backends implemented, to reflect the final behavior.
LGTM -- Thank you!
This will be used for sharding the intermediate state (e.g. activations).

Updates to the API:
- Added `keras.distribute.relayout` to relayout / set a sharding constraint on a tensor value.

Unit tests have been updated for JAX as a demonstration.

This should fix #18521