diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py
index c829aa948f..06f5b0f601 100644
--- a/keras_nlp/models/gemma/gemma_backbone.py
+++ b/keras_nlp/models/gemma/gemma_backbone.py
@@ -194,7 +194,11 @@ def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @staticmethod
-    def get_layout_map(device_mesh, model_parallel_dim_name="model"):
+    def get_layout_map(
+        device_mesh,
+        model_parallel_dim_name="model",
+        data_parallel_dim_name="batch",
+    ):
         """Get a `keras.distribution.LayoutMap` for model parallel distribution.
 
         The returned `LayoutMap` contains the sharding spec for the gemma
@@ -221,6 +225,8 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 distribution.
             model_parallel_dim_name: The axis name of the device mesh, where
                 the weights should be partition on.
+            data_parallel_dim_name: The axis name of the device mesh, where
+                the data should be partition on.
         Return:
             `keras.distribution.LayoutMap` that contains the sharding spec of
             all the model weights.
@@ -248,21 +254,30 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 f"{model_parallel_dim_name} is not found in the "
                 f"device_mesh.axis_names. {device_mesh.axis_name=}"
             )
+        if data_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{data_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_name=}"
+            )
+        # Note that it is possible to further config the mesh to be 3D, eg
+        # (data, seq, model). We leave it as 2D for now for simplicity.
+        data_dim = data_parallel_dim_name
         model_dim = model_parallel_dim_name
-        # The sharding is partition for the hidden_dim of the model.
+        # The sharding config is based on the Gemma team training config.
+        # See https://arxiv.org/abs/2403.08295
         layout_map = keras.distribution.LayoutMap(device_mesh)
-        layout_map["token_embedding/embeddings"] = (None, model_dim)
+        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
         layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
-            None,
             model_dim,
+            data_dim,
             None,
         )
         layout_map["decoder_block.*attention_output.*kernel"] = (
-            None,
-            None,
             model_dim,
+            None,
+            data_dim,
         )
-        layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
-        layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
+        layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)
         return layout_map
 
diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py
index 855d49658b..7b02de2b7a 100644
--- a/keras_nlp/models/gemma/gemma_backbone_test.py
+++ b/keras_nlp/models/gemma/gemma_backbone_test.py
@@ -106,26 +106,34 @@ def test_distribution(self):
 
         for w in model.weights:
             if "token_embedding/embeddings" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )
             if "attention/query/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/key/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/value/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/attention_output/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, None, "model")
+                    tuple(w.value.sharding.spec), ("model", None, "batch")
                 )
             if "ffw_gating/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_gating_2/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_linearl" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )
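
For reference, a minimal usage sketch of the updated get_layout_map signature (not part of this diff). It assumes 8 accelerators arranged as a 2x4 mesh, keeps the default axis names ("batch", "model"), and constructs ModelParallel with the device mesh and layout map the same way the existing test does; newer Keras releases may expect keyword arguments instead. The gemma_2b_en preset name is only an example.

import keras
from keras_nlp.models import GemmaBackbone

# Assumes 8 accelerators; adjust the mesh shape to the devices available.
device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=("batch", "model"),  # must match the *_dim_name arguments below
    devices=keras.distribution.list_devices(),
)
layout_map = GemmaBackbone.get_layout_map(
    device_mesh,
    model_parallel_dim_name="model",
    data_parallel_dim_name="batch",
)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
keras.distribution.set_distribution(distribution)

# Weights created from here on are sharded according to layout_map;
# input batches are split along the "batch" mesh dimension.
model = GemmaBackbone.from_preset("gemma_2b_en")  # example preset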