diff --git a/maze_transformer/training/config.py b/maze_transformer/training/config.py
index aa48a6cb..25c6b1ec 100644
--- a/maze_transformer/training/config.py
+++ b/maze_transformer/training/config.py
@@ -39,6 +39,10 @@ class BaseGPTConfig(SerializableDataclass):
     d_model: int
     d_head: int
     n_layers: int
+    positional_embedding_type: str = serializable_field(
+        default="standard",
+        loading_fn=lambda data: data.get("positional_embedding_type", "standard"),
+    )

     weight_processing: dict[str, bool] = serializable_field(
         default_factory=lambda: dict(
@@ -59,6 +63,7 @@ def summary(self) -> dict:
             d_model=self.d_model,
             d_head=self.d_head,
             n_layers=self.n_layers,
+            positional_embedding_type=self.positional_embedding_type,
             weight_processing=self.weight_processing,
             n_heads=self.n_heads,
         )
@@ -501,6 +506,7 @@ def hooked_transformer_cfg(self) -> HookedTransformerConfig:
             d_model=self.model_cfg.d_model,
             d_head=self.model_cfg.d_head,
             n_layers=self.model_cfg.n_layers,
+            positional_embedding_type=self.model_cfg.positional_embedding_type,
             n_ctx=self.dataset_cfg.seq_len_max,
             d_vocab=self.maze_tokenizer.vocab_size,
         )
diff --git a/tests/unit/maze_transformer/training/config/test_base_gpt_config.py b/tests/unit/maze_transformer/training/config/test_base_gpt_config.py
index 84b70497..e88245e1 100644
--- a/tests/unit/maze_transformer/training/config/test_base_gpt_config.py
+++ b/tests/unit/maze_transformer/training/config/test_base_gpt_config.py
@@ -63,6 +63,7 @@ def _custom_serialized_config():
         "d_head": 1,
         "n_layers": 1,
         "n_heads": 1,
+        "positional_embedding_type": "standard",
        "weight_processing": {
            "are_layernorms_folded": False,
            "are_weights_processed": False,
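
A minimal sketch of how the new field flows through, for reviewers. The diff adds a `positional_embedding_type` field to `BaseGPTConfig` (with `loading_fn` falling back to `"standard"` so older serialized configs still load) and passes it through to TransformerLens's `HookedTransformerConfig`. The constructor arguments below other than those visible in this diff (`name`, `act_fn`) and the `"rotary"` value are assumptions, not confirmed by the diff itself:

    # hypothetical usage sketch, not part of this PR
    from maze_transformer.training.config import BaseGPTConfig

    cfg = BaseGPTConfig(
        name="tiny-rotary",   # assumed field
        act_fn="gelu",        # assumed field
        d_model=32,
        d_head=16,
        n_layers=2,
        # assumes a value accepted by HookedTransformerConfig
        positional_embedding_type="rotary",
    )
    # summary() now reports the embedding type (per the second hunk)
    assert cfg.summary()["positional_embedding_type"] == "rotary"

    # deserializing an old config dict without the key falls back to "standard",
    # per the loading_fn added in the first hunk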