update maze-dataset dep and poetry lockfile (#213)
* update maze-dataset dep and poetry lockfile

* fixed some imports

* update another import

* another import fix

* hookedtransformer maze tokenizer compat

* update wandb dep

* update transformer_lens dep

* fix import in notebook

* re-ran this notebook; it was failing in CI for an unclear reason

* better error when dataset configs don't match

* (run format) - works locally, but configs differ in CI:

```
ValueError: ('dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False', "{'applied_filters': {'self': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'other': []}}")
```

This is probably because we are loading a dataset with the new format; the patch will be hacky.

* special case for applied filters diff
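
For clarity, here is a minimal sketch of that special case in isolation. It assumes the dataset configs expose a muutils-style `.diff()` returning the nested dict shown in the error above; the `check_cfg_diff` helper is hypothetical, and the real patch is in `maze_transformer/training/train_model.py` below.

```python
# Hypothetical sketch of the applied-filters special case; the actual patch
# lives in maze_transformer/training/train_model.py (see the diff below).
import warnings

# the one diff shape treated as benign: the loaded dataset has the
# "collect_generation_meta" filter applied, the expected config does not
BENIGN_DIFF: dict = {
    "applied_filters": {
        "self": [{"name": "collect_generation_meta", "args": (), "kwargs": {}}],
        "other": [],
    }
}


def check_cfg_diff(datasets_cfg_diff: dict) -> None:
    """warn on the known-benign applied_filters diff, raise on anything else"""
    if datasets_cfg_diff == BENIGN_DIFF:
        warnings.warn(
            "configs differ only in applied_filters; using the passed dataset"
        )
    else:
        raise ValueError("dataset config mismatch", f"{datasets_cfg_diff = }")
```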
mivanit authored May 14, 2024
1 parent 495d8d3 commit 20092c3
Showing 10 changed files with 1,496 additions and 1,488 deletions.
2 changes: 1 addition & 1 deletion maze_transformer/evaluation/baseline_models.py
```diff
@@ -15,8 +15,8 @@
     get_origin_tokens,
     get_path_tokens,
     get_target_tokens,
-    strings_to_coords,
 )
+from maze_dataset.tokenization.util import strings_to_coords
 from transformer_lens import HookedTransformer
 
 from maze_transformer.training.config import ConfigHolder
```
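
The same import relocation (from `maze_dataset.tokenization.token_utils` to `maze_dataset.tokenization.util`) recurs in the files below. The commit simply updates the import paths; code that needed to run against both maze-dataset versions could use a shim like this sketch, which is not part of the commit:

```python
# Hypothetical compatibility shim, not part of the commit.
try:
    # maze-dataset >= 0.5: helper moved to tokenization.util
    from maze_dataset.tokenization.util import strings_to_coords
except ImportError:
    # maze-dataset < 0.5: old location
    from maze_dataset.tokenization.token_utils import strings_to_coords
```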
6 changes: 3 additions & 3 deletions maze_transformer/evaluation/eval_model.py
```diff
@@ -16,12 +16,12 @@
 )
 from maze_dataset.tokenization import MazeTokenizer
 from maze_dataset.tokenization.token_utils import (
-    WhenMissing,
     get_context_tokens,
     get_path_tokens,
     remove_padding_from_token_str,
-    strings_to_coords,
 )
+from maze_dataset.tokenization.util import strings_to_coords
+from maze_dataset.utils import WhenMissing
 
 # muutils
 from muutils.mlutils import chunks
@@ -143,7 +143,7 @@ def predict_maze_paths(
         smart_max_new_tokens
     ), "if max_new_tokens is None, smart_max_new_tokens must be True"
 
-    maze_tokenizer: MazeTokenizer = model.config.maze_tokenizer
+    maze_tokenizer: MazeTokenizer = model.tokenizer._maze_tokenizer
 
     contexts_lists: list[list[str]] = [
         get_context_tokens(tokens) for tokens in tokens_batch
```
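
The switch from `model.config.maze_tokenizer` to `model.tokenizer._maze_tokenizer` here (and in `plotting.py` below) is the "hookedtransformer maze tokenizer compat" change from the commit message. A defensive accessor one might wrap around it, purely as a sketch: the helper is not in the commit, and the private attribute name is taken from the diff above.

```python
# Hypothetical helper, not part of the commit: fetch the MazeTokenizer from
# its new home on the wrapper tokenizer, with a clearer failure mode.
from maze_dataset.tokenization import MazeTokenizer


def get_maze_tokenizer(model) -> MazeTokenizer:
    maze_tokenizer = getattr(model.tokenizer, "_maze_tokenizer", None)
    if maze_tokenizer is None:
        raise AttributeError(
            "model.tokenizer has no _maze_tokenizer attribute; "
            "was the model built with a maze-dataset >= 0.5 tokenizer wrapper?"
        )
    return maze_tokenizer
```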
2 changes: 1 addition & 1 deletion maze_transformer/evaluation/plotting.py
```diff
@@ -32,7 +32,7 @@ def plot_predicted_paths(
     if n_mazes is None:
         n_mazes = len(dataset)
 
-    dataset_tokens = dataset.as_tokens(model.config.maze_tokenizer)[:n_mazes]
+    dataset_tokens = dataset.as_tokens(model.tokenizer._maze_tokenizer)[:n_mazes]
 
     # predict
     predictions: list[list[str | tuple[int, int]]] = predict_maze_paths(
```
2 changes: 1 addition & 1 deletion maze_transformer/mechinterp/plot_attention.py
```diff
@@ -19,7 +19,7 @@
 from maze_dataset.plotting.plot_tokens import plot_colored_text
 from maze_dataset.plotting.print_tokens import color_tokens_cmap
 from maze_dataset.tokenization import MazeTokenizer
-from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable
+from maze_dataset.tokenization.util import coord_str_to_tuple_noneable
 
 # Utilities
 from muutils.json_serialize import SerializableDataclass, serializable_dataclass
```
2 changes: 1 addition & 1 deletion maze_transformer/mechinterp/residual_stream_structure.py
```diff
@@ -12,7 +12,7 @@
 # maze_dataset
 from maze_dataset.constants import _SPECIAL_TOKENS_ABBREVIATIONS
 from maze_dataset.tokenization import MazeTokenizer
-from maze_dataset.tokenization.token_utils import strings_to_coords
+from maze_dataset.tokenization.util import strings_to_coords
 
 # scipy
 from scipy.spatial.distance import pdist, squareform
```
26 changes: 23 additions & 3 deletions maze_transformer/training/train_model.py
```diff
@@ -1,5 +1,6 @@
 import json
 import typing
+import warnings
 from pathlib import Path
 from typing import Union
 
@@ -122,9 +123,28 @@ def train_model(
             f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset"
         )
     else:
-        raise ValueError(
-            f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False"
-        )
+        datasets_cfg_diff: dict = dataset.cfg.diff(cfg.dataset_cfg)
+        if datasets_cfg_diff == {
+            "applied_filters": {
+                "self": [
+                    {
+                        "name": "collect_generation_meta",
+                        "args": (),
+                        "kwargs": {},
+                    }
+                ],
+                "other": [],
+            }
+        }:
+            warnings.warn(
+                f"dataset has different config than cfg.dataset_cfg, but the only difference is in applied_filters, so using passed dataset. This is due to fast dataset loading collecting generation metadata for performance reasons"
+            )
+
+        else:
+            raise ValueError(
+                f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False",
+                f"{datasets_cfg_diff = }",
+            )
 
     logger.progress(f"finished getting training dataset with {len(dataset)} samples")
     # validation dataset, if applicable
```
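
To illustrate how the new branch behaves, a hedged example using the `check_cfg_diff` sketch from earlier; the second diff dict is invented for demonstration:

```python
# benign diff: only the collect_generation_meta filter differs -> warning only
check_cfg_diff(
    {
        "applied_filters": {
            "self": [
                {"name": "collect_generation_meta", "args": (), "kwargs": {}}
            ],
            "other": [],
        }
    }
)

# any other diff (this one is made up) -> still raises
try:
    check_cfg_diff({"grid_n": {"self": 6, "other": 5}})
except ValueError as e:
    print(e)
```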
16 changes: 8 additions & 8 deletions notebooks/residual_stream_decoding.ipynb

Large diffs are not rendered by default.

197 changes: 87 additions & 110 deletions notebooks/train_model.ipynb

Large diffs are not rendered by default.

2,724 changes: 1,367 additions & 1,357 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyproject.toml
```diff
@@ -10,15 +10,16 @@ repository = "https://github.com/understanding-search/maze-transformer"
 [tool.poetry.dependencies]
 python = ">=3.10,<3.13"
 # dataset
-maze-dataset = "^0.4.5"
+maze-dataset = "^0.5.2"
 # transformers
 torch = ">=1.13.1"
-transformer-lens = "1.14.0"
+transformer-lens = "^1.14.0"
 transformers = ">=4.34" # Dependency in transformer-lens 1.14.0
 # utils
 muutils = "^0.5.5"
 zanj = "^0.2.0"
-wandb = "^0.13.5" # note: TransformerLens forces us to use 0.13.5
+# wandb = "^0.13.5" # note: TransformerLens forces us to use 0.13.5
+wandb = "^0.17.0"
 fire = "^0.5.0"
 typing-extensions = "^4.8.0"
 # plotting
```
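
After regenerating the lockfile and installing, a quick sanity check (not part of the commit) to confirm the bumped pins resolved in the active environment:

```python
# Not part of the commit: print the installed versions of the bumped deps.
import importlib.metadata as md

for pkg in ("maze-dataset", "transformer-lens", "wandb"):
    try:
        print(f"{pkg}: {md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg}: not installed")
```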
