diff --git a/keras/distribution/distribution_lib.py b/keras/distribution/distribution_lib.py index 37fbe280c58..8425f9da3a6 100644 --- a/keras/distribution/distribution_lib.py +++ b/keras/distribution/distribution_lib.py @@ -327,10 +327,10 @@ class ModelParallel(Distribution): # will be split across 4 devices. Any other variable that doesn't # match any key in the layout map will be fully replicated. layout_map = LayoutMap(device_mesh) - layout_map['.*dense.*kernel'] = (None, 'model') - layout_map['.*dense.*bias'] = ('model',) - layout_map['.*conv2d.*kernel'] = (None, None, None, 'model') - layout_map['.*conv2d.*bias'] = ('model',) + layout_map['dense.*kernel'] = (None, 'model') + layout_map['dense.*bias'] = ('model',) + layout_map['conv2d.*kernel'] = (None, None, None, 'model') + layout_map['conv2d.*bias'] = ('model',) distribution = ModelParallel(device_mesh=device_mesh, layout_map=layout_map, @@ -415,10 +415,10 @@ class LayoutMap(collections.abc.MutableMapping): ```python layout_map = LayoutMap(device_mesh=None) - layout_map['.*dense.*kernel'] = (None, 'model') # layout_2d - layout_map['.*dense.*bias'] = ('model',) # layout_1d - layout_map['.*conv2d.*kernel'] = TensorLayout((None, None, None, 'model')) - layout_map['.*conv2d.*bias'] = TensorLayout(('model',)) # layout_1d + layout_map['dense.*kernel'] = (None, 'model') # layout_2d + layout_map['dense.*bias'] = ('model',) # layout_1d + layout_map['conv2d.*kernel'] = TensorLayout((None, None, None, 'model')) + layout_map['conv2d.*bias'] = TensorLayout(('model',)) # layout_1d layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d @@ -443,9 +443,9 @@ def __getitem__(self, key): """Retrieves the corresponding layout by the string key. When there isn't an exact match, all the existing keys in the layout map - will be treated as a regex and map against the input key again. The - first match will be returned, based on the key insertion order. 
Returns - `None` if there isn't any match found. + will be treated as a regex and matched against the input key again. When + there are multiple matches for the regex, a `ValueError` will be + raised. Returns `None` if there isn't any match found. Args: key: String key to query a layout. @@ -456,9 +456,19 @@ if key in self._layout_map: return self._layout_map[key] + matching_keys = [] for k in self._layout_map: - if re.match(k, key): - return self._layout_map[k] + if re.search(k, key): + matching_keys.append(k) + if len(matching_keys) > 1: + raise ValueError( + f"Path '{key}' matches multiple layout " + f"specification keys: {matching_keys}. Please make " + "sure each tensor/variable path only matches at most " + "one layout specification key in the LayoutMap." + ) + elif len(matching_keys) == 1: + return self._layout_map[matching_keys[0]] return None def __setitem__(self, key, layout): diff --git a/keras/distribution/distribution_lib_test.py b/keras/distribution/distribution_lib_test.py index a45ce3c0c89..c19814f45c1 100644 --- a/keras/distribution/distribution_lib_test.py +++ b/keras/distribution/distribution_lib_test.py @@ -293,15 +293,18 @@ def test_get(self): layout_map["dense.*kernel"] = self.replicated_2d layout_map["dense.*bias"] = self.replicated_1d - layout_map[".*bias"] = self.sharded_1d + layout_map["bias"] = self.sharded_1d self.assertEqual(layout_map["dense/kernel"], self.sharded_2d) self.assertEqual(layout_map["dense/bias"], self.sharded_1d) - # Map against the wildcard bias rule for dense, and based on the order - # of insertion, it will not use .*bias. self.assertEqual(layout_map["dense_2/kernel"], self.replicated_2d) - self.assertEqual(layout_map["dense_2/bias"], self.replicated_1d) + # Map against the wildcard bias rule for dense. 
This will cause a + # ValueError + with self.assertRaisesRegex( + ValueError, "Path 'dense_2/bias' matches multiple layout" + ): + layout_map["dense_2/bias"] self.assertIsNone(layout_map["conv2d/kernel"]) self.assertEqual(layout_map["conv2d/bias"], self.sharded_1d)