Modify configs to allow specification of pos encoding type
afspies committed Dec 27, 2023
1 parent b3417f9 commit fc02dc0
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions maze_transformer/training/config.py
@@ -39,6 +39,7 @@ class BaseGPTConfig(SerializableDataclass):
d_model: int
d_head: int
n_layers: int
+ positional_embedding_type: str

weight_processing: dict[str, bool] = serializable_field(
default_factory=lambda: dict(
@@ -59,6 +60,7 @@ def summary(self) -> dict:
d_model=self.d_model,
d_head=self.d_head,
n_layers=self.n_layers,
+ positional_embedding_type=self.positional_embedding_type,
weight_processing=self.weight_processing,
n_heads=self.n_heads,
)
@@ -275,20 +277,23 @@ def summary(self) -> dict:
d_model=32,
d_head=16,
n_layers=4,
+ positional_embedding_type='standard',
),
BaseGPTConfig(
name="tuned-v1",
act_fn="gelu",
d_model=384,
d_head=64,
n_layers=6,
+ positional_embedding_type='standard',
),
BaseGPTConfig(
name="gpt2-small",
act_fn="gelu",
d_model=384, # half of gpt2-small
d_head=64, # match gpt-2 small
n_layers=12, # half of gpt2-small
+ positional_embedding_type='standard',
),
# this one is just for integration tests
BaseGPTConfig(
@@ -297,6 +302,7 @@ def summary(self) -> dict:
d_model=8,
d_head=4,
n_layers=2,
+ positional_embedding_type='standard',
),
]

@@ -496,6 +502,7 @@ def hooked_transformer_cfg(self) -> HookedTransformerConfig:
d_model=self.model_cfg.d_model,
d_head=self.model_cfg.d_head,
n_layers=self.model_cfg.n_layers,
+ positional_embedding_type=self.model_cfg.positional_embedding_type,
n_ctx=self.dataset_cfg.seq_len_max,
d_vocab=self.maze_tokenizer.vocab_size,
)
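For context, a minimal sketch of how the new field might be specified when defining a model config. The BaseGPTConfig fields are taken from the diff above; the name "tiny-rotary" and the value "rotary" are hypothetical, assuming positional_embedding_type accepts the values understood by TransformerLens's HookedTransformerConfig (e.g. "standard", "rotary", "shortformer"), to which the field is forwarded in hooked_transformer_cfg.

    from maze_transformer.training.config import BaseGPTConfig

    # Hypothetical config exercising the new field. "rotary" is assumed to be
    # accepted because the value is passed straight through to TransformerLens's
    # HookedTransformerConfig, which supports it alongside "standard" and
    # "shortformer".
    rotary_cfg = BaseGPTConfig(
        name="tiny-rotary",  # hypothetical name, not one of the bundled configs
        act_fn="gelu",
        d_model=32,
        d_head=16,
        n_layers=4,
        positional_embedding_type="rotary",
    )

The bundled configs in the diff all pin positional_embedding_type='standard', which matches HookedTransformerConfig's default, so existing models keep their previous behaviour while the encoding type becomes an explicit, serializable choice.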
