Skip to content

Commit

Permalink
fix up pos embed
Browse files — browse the repository at this point in the history
  • Loading branch information
mivanit committed Jul 26, 2024
1 parent fc02dc0 commit 41466c0
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions maze_transformer/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class BaseGPTConfig(SerializableDataclass):
d_model: int
d_head: int
n_layers: int
positional_embedding_type: str
positional_embedding_type: str = serializable_field(
default="standard",
loading_fn=lambda data: data.get("positional_embedding_type", "standard"),
)

weight_processing: dict[str, bool] = serializable_field(
default_factory=lambda: dict(
Expand Down Expand Up @@ -277,23 +280,20 @@ def summary(self) -> dict:
d_model=32,
d_head=16,
n_layers=4,
positional_embedding_type='standard',
),
BaseGPTConfig(
name="tuned-v1",
act_fn="gelu",
d_model=384,
d_head=64,
n_layers=6,
positional_embedding_type='standard',
),
BaseGPTConfig(
name="gpt2-small",
act_fn="gelu",
d_model=384, # half of gpt2-small
d_head=64, # match gpt-2 small
n_layers=12, # half of gpt2-small
positional_embedding_type='standard',
),
# this one is just for integration tests
BaseGPTConfig(
Expand All @@ -302,7 +302,6 @@ def summary(self) -> dict:
d_model=8,
d_head=4,
n_layers=2,
positional_embedding_type='standard',
),
]

Expand Down

0 comments on commit 41466c0

Please sign in to comment.