Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update layout_map to use re.search instead of re.match. #18555

Merged
merged 4 commits into from
Oct 6, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address review comments
  • Loading branch information
qlzh727 committed Oct 6, 2023
commit 7c3e5406aaa6e4904d43916e5e79fc6cea3ff8d8
25 changes: 11 additions & 14 deletions keras/distribution/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,34 +444,31 @@ def __getitem__(self, key):

When there isn't an exact match, all the existing keys in the layout map
will be treated as a regex and map against the input key again. When
there are multiple matches for the regex, an ValueError will be raised.
Returns `None` if there isn't any match found.
there are multiple matches for the regex, an `ValueError` will be
raised. Returns `None` if there isn't any match found.

Args:
key: String key to query a layout.

Returns:
Corresponding layout based on the query.

Raises:
ValueError when multiple keys are matched if the keys are treated
as regex.
"""
if key in self._layout_map:
return self._layout_map[key]

matching_key = []
matching_keys = []
for k in self._layout_map:
if re.search(k, key):
matching_key.append(k)
if len(matching_key) > 1:
matching_keys.append(k)
if len(matching_keys) > 1:
raise ValueError(
f"The input {key} has matched to multiple layout "
f"rule: {matching_key}. Please make sure the "
"key only match to one rule."
f"Path '{key}' matches multiple layout "
f"specification keys: {matching_keys}. Please make "
"sure each tensor/variable path only matches at most "
"one layout specification key in the LayoutMap."
)
elif len(matching_key) == 1:
return self._layout_map[matching_key[0]]
elif len(matching_keys) == 1:
return self._layout_map[matching_keys[0]]
return None

def __setitem__(self, key, layout):
Expand Down
Loading