general dependency housekeeping (#218)
* added test for py 3.12, torch 2.4.0

* poetry update

* update deps

- no longer support torch < 2.0, testing only on torch 2.4.0
- update muutils, zanj, transformer-lens deps to latest versions
  (this might break things. we will see)

* fix @freeze

* added test for TRAIN_SAVE_FILES frozen

* run format

* fix minor pytorch warning

* fix old hash keys

* update for upstream maze-dataset fixes to tokenizer

* fix duplicate 3.10 run in CI

* update hash keys in notebooks

* re-run and fix notebooks

* fix nb

* re-run nb, minor fixes to saving data
mivanit authored Jul 26, 2024
1 parent 6a2c993 commit 45348cb
Showing 12 changed files with 1,269 additions and 1,139 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/checks.yml
@@ -40,11 +40,11 @@ jobs:
       matrix:
         versions:
           - python: "3.10"
-            torch: "1.13.1"
-          - python: "3.10"
-            torch: "2.0.1"
+            torch: "2.4.0"
           - python: "3.11"
-            torch: "2.0.1"
+            torch: "2.4.0"
+          - python: "3.12"
+            torch: "2.4.0"
     steps:
       - name: Checkout code
         uses: actions/checkout@v3
42 changes: 30 additions & 12 deletions maze_transformer/training/config.py
@@ -28,6 +28,16 @@
 from maze_transformer.tokenizer import HuggingMazeTokenizer
 
 
+# TODO: replace with muutils
+def dynamic_docstring(**doc_params):
+    def decorator(func):
+        if func.__doc__:
+            func.__doc__ = func.__doc__.format(**doc_params)
+        return func
+
+    return decorator
+
+
 @serializable_dataclass(kw_only=True, properties_to_serialize=["n_heads"])
 class BaseGPTConfig(SerializableDataclass):
     """
@@ -528,6 +538,11 @@ def create_model_zanj(self) -> ZanjHookedTransformer:
         return ZanjHookedTransformer(self)
 
     @classmethod
+    @dynamic_docstring(
+        dataset_cfg_names=str(list(MAZE_DATASET_CONFIGS.keys())),
+        model_cfg_names=str(list(GPT_CONFIGS.keys())),
+        train_cfg_names=str(list(TRAINING_CONFIGS.keys())),
+    )
     def get_config_multisource(
         cls,
         cfg: ConfigHolder | None = None,
@@ -536,18 +551,14 @@ def get_config_multisource(
         kwargs_in: dict | None = None,
     ) -> ConfigHolder:
         """pass one of cfg object, file, or list of names. Any kwargs will be applied to the config object (and should start with 'cfg.')
         cfg_names should be either `(dataset_cfg_name,model_cfg_name,train_cfg_name)` or the same with collective name at the end
         valid name keys:
         - dataset_cfg_name: {dataset_cfg_names}
         - model_cfg_name: {model_cfg_names}
         - train_cfg_name: {train_cfg_names}
-        """.format(
-            dataset_cfg_names=str(list(MAZE_DATASET_CONFIGS.keys())),
-            model_cfg_names=str(list(GPT_CONFIGS.keys())),
-            train_cfg_names=str(list(TRAINING_CONFIGS.keys())),
-        )
+        """
 
         config: ConfigHolder
         assert (
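Moving the .format() call out of the docstring and into the decorator is more than cosmetic, and is likely the motivation for this change (the diff itself does not state it): a string literal followed by .format(...) is an ordinary expression, so CPython never stores it as __doc__. A toy demonstration, with a made-up function name:

    def old_style():
        """valid keys: {keys}""".format(keys=["a", "b"])


    print(old_style.__doc__)  # None -- the formatted string is built (and
    # discarded) only when the function runs; it is never stored as a docstring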
@@ -573,12 +584,19 @@ def get_config_multisource(
                 name = f"multsrc_{dataset_cfg_name}_{model_cfg_name}_{train_cfg_name}"
             else:
                 dataset_cfg_name, model_cfg_name, train_cfg_name, name = cfg_names
-            config = ConfigHolder(
-                name=name,
-                dataset_cfg=MAZE_DATASET_CONFIGS[dataset_cfg_name],
-                model_cfg=GPT_CONFIGS[model_cfg_name],
-                train_cfg=TRAINING_CONFIGS[train_cfg_name],
-            )
+            try:
+                config = ConfigHolder(
+                    name=name,
+                    dataset_cfg=MAZE_DATASET_CONFIGS[dataset_cfg_name],
+                    model_cfg=GPT_CONFIGS[model_cfg_name],
+                    train_cfg=TRAINING_CONFIGS[train_cfg_name],
+                )
+            except KeyError as e:
+                raise KeyError(
+                    "tried to get a config that doesn't exist, check the names.\n",
+                    f"{dataset_cfg_name = }, {model_cfg_name = }, {train_cfg_name = }\n",
+                    ConfigHolder.get_config_multisource.__doc__,
+                ) from e
 
         else:
             raise ValueError(
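A sketch of the call path this guards, with hypothetical config names (the real keys live in MAZE_DATASET_CONFIGS, GPT_CONFIGS, and TRAINING_CONFIGS):

    # names here are made up -- substitute keys that exist in your configs
    cfg = ConfigHolder.get_config_multisource(
        cfg_names=("my_dataset_cfg", "my_model_cfg", "my_train_cfg"),
    )
    # a typo in any of the three names now raises a KeyError that reports
    # all three names plus the docstring listing the valid keys, instead of
    # the bare missing-key error from the underlying dict lookup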
31 changes: 25 additions & 6 deletions maze_transformer/training/train_save_files.py
@@ -1,29 +1,48 @@
 from datetime import datetime
 from typing import Callable
 
-from muutils.misc import freeze, sanitize_fname  # type: ignore[import]
+from muutils.misc import sanitize_fname  # type: ignore[import]
 
 from maze_transformer.training.config import ConfigHolder
 
 
-@freeze
-class TRAIN_SAVE_FILES:
+class _TRAIN_SAVE_FILES:
     """namespace for filenames/formats for saving training data"""
 
     # old
     data_cfg: str = "data_config.json"
     train_cfg: str = "train_config.json"
-    model_checkpt: Callable[[int], str] = lambda iteration: f"model.iter_{iteration}.pt"
+    model_checkpt: Callable[[int], str] = (
+        lambda _, iteration: f"model.iter_{iteration}.pt"
+    )
     model_final: str = "model.final.pt"
 
     # keep these
     config_holder: str = "config.json"
     checkpoints: str = "checkpoints"
     log: str = "log.jsonl"
     model_checkpt_zanj: Callable[[int], str] = (
-        lambda iteration: f"model.iter_{iteration}.zanj"
+        lambda _, iteration: f"model.iter_{iteration}.zanj"
     )
     model_final_zanj: str = "model.final.zanj"
     model_run_dir: Callable[[ConfigHolder], str] = (
-        lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
+        lambda _, cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
     )
+
+    @classmethod
+    def __class_getitem__(cls, _):
+        return cls
+
+    def __class_getattribute__(cls, name):
+        if name.startswith("__"):
+            return super().__class_getattribute__(name)
+        attr = cls.__dict__[name]
+        return attr
+
+    def __setattr__(self, name, value):
+        raise AttributeError("TRAIN_SAVE_FILES is read-only")
+
+    __delattr__ = __setattr__
+
+
+TRAIN_SAVE_FILES = _TRAIN_SAVE_FILES()
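The commit message mentions a test that TRAIN_SAVE_FILES is frozen; that test file is not shown in this diff, but given the class above it would look roughly like this:

    import pytest

    from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES


    def test_train_save_files_frozen():
        # reads still work; the leading `_` lambda parameter absorbs the
        # bound instance, so the callables behave like plain functions
        assert TRAIN_SAVE_FILES.model_final_zanj == "model.final.zanj"
        assert TRAIN_SAVE_FILES.model_checkpt_zanj(7) == "model.iter_7.zanj"
        # writes are rejected by the overridden __setattr__
        with pytest.raises(AttributeError):
            TRAIN_SAVE_FILES.model_final_zanj = "something_else"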
70 changes: 32 additions & 38 deletions notebooks/demo_dataset.ipynb

Large diffs are not rendered by default.

198 changes: 99 additions & 99 deletions notebooks/eval_tasks_table.ipynb

Large diffs are not rendered by default.

126 changes: 82 additions & 44 deletions notebooks/train_model.ipynb

Large diffs are not rendered by default.
