-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
general dependency housekeeping (#218)
* added test for py 3.12, torch 2.4.0 * poetry update * update deps - no longer support torch < 2.0, testing only on torch 2.4.0 - update muutils, zanj, transformer-lens deps to latest versions (this might break things. we will see) * fix @Freeze * added test for TRAIN_SAVE_FILES frozen * run format * fix minor pytorch warning * fix old hash keys * update for upstream maze-dataset fixes to tokenizer * fix duplicate 3.10 run in CI * update hash keys in notebooks * re-run and fix notebooks * fix nb * re-run nb, minor fixes to saving data
- Loading branch information
Showing
12 changed files
with
1,269 additions
and
1,139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,48 @@ | ||
from datetime import datetime | ||
from typing import Callable | ||
|
||
from muutils.misc import freeze, sanitize_fname # type: ignore[import] | ||
from muutils.misc import sanitize_fname # type: ignore[import] | ||
|
||
from maze_transformer.training.config import ConfigHolder | ||
|
||
|
||
@freeze | ||
class TRAIN_SAVE_FILES: | ||
class _TRAIN_SAVE_FILES: | ||
"""namespace for filenames/formats for saving training data""" | ||
|
||
# old | ||
data_cfg: str = "data_config.json" | ||
train_cfg: str = "train_config.json" | ||
model_checkpt: Callable[[int], str] = lambda iteration: f"model.iter_{iteration}.pt" | ||
model_checkpt: Callable[[int], str] = ( | ||
lambda _, iteration: f"model.iter_{iteration}.pt" | ||
) | ||
model_final: str = "model.final.pt" | ||
|
||
# keep these | ||
config_holder: str = "config.json" | ||
checkpoints: str = "checkpoints" | ||
log: str = "log.jsonl" | ||
model_checkpt_zanj: Callable[[int], str] = ( | ||
lambda iteration: f"model.iter_{iteration}.zanj" | ||
lambda _, iteration: f"model.iter_{iteration}.zanj" | ||
) | ||
model_final_zanj: str = "model.final.zanj" | ||
model_run_dir: Callable[[ConfigHolder], str] = ( | ||
lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" | ||
lambda _, cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" | ||
) | ||
|
||
@classmethod | ||
def __class_getitem__(cls, _): | ||
return cls | ||
|
||
def __class_getattribute__(cls, name): | ||
if name.startswith("__"): | ||
return super().__class_getattribute__(name) | ||
attr = cls.__dict__[name] | ||
return attr | ||
|
||
def __setattr__(self, name, value): | ||
raise AttributeError("TRAIN_SAVE_FILES is read-only") | ||
|
||
__delattr__ = __setattr__ | ||
|
||
|
||
TRAIN_SAVE_FILES = _TRAIN_SAVE_FILES() |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Oops, something went wrong.