structuredconfig for train.py and eval.py (#1051)
* first commit for structuredconfig for train.py

* revamp configs

* wip latest issue

* reorder so mandatory attributes come first

* fix

* fix

* fix fix

* fix types

* fix dictconfig

* fix union of list|dict configs

* fix type annotation

* oops

* fixed configs

* add save ignore keys

* fix batch size kerfuffle

* fix dictconfig stuff

* fix dictconfig stuff again

* fix

* fix

* updated unit tests for variables

* last fix?

* if this test case does not pass I will venmo Mihir 0

* remove a 'not' -- eg. 'I am not going crazy'

* Update scripts/train/train.py

Co-authored-by: Daniel King <[email protected]>

* set amp bf16 as default precision, etc

* temporarily wrap with dictconfig before ** migration

* fix icl tasks

* fix

* fix activation checkpointing reentrant

* fix extraneous keys

* first round **

* fix?

* quick fsdp config fix

* updated yamls to make variables explicit

* remove precision from mandatory params list

* I expect many of these to fail in interesting ways

* fix test_model test cases with **

* fix many more test cases

* fix dictconfig objectification

* fix remaining  test cases

* remove unneeded **

* fix test case

* changed back argument name

* fix

* ** for finetuning dataloader

* fix?

* fix dataloader

* fix

* fix finetuning dataloader

* fix build_text_dataloader

* left to my own devices

* fix packing

* fix typo

* fix padding test cases

* ignore extra parameters and warn

* fix style

* fix quality checks

* fix code quality

* pyright-fu

* fix

* just one more type constraint bro

* OmegaConf -> om

* rename variables for clarity

* revert file

* revert file II

* revert file III: revert of the sith

* peft revert file

* revert v_mpt

* last revert

* remove redundant checks

* deprecate

* make cleaner

* pyright is bullying me again

* further clean config_utils

* polish train

* polish train and eval

* fix dist

* fix style

* organize eval and train

* fix

* used helper function to make main cleaner

* fix stuff

* fix pyright

* added fix and explanation

* fix typo in unit test update smh

* Update llmfoundry/registry.py

Co-authored-by: Daniel King <[email protected]>

* Update scripts/train/train.py

Co-authored-by: Daniel King <[email protected]>

* Update scripts/train/train.py

Co-authored-by: Daniel King <[email protected]>

* Update scripts/train/train.py

Co-authored-by: Daniel King <[email protected]>

* Apply suggestions from code review

Co-authored-by: Daniel King <[email protected]>

* see if this fails

* reject name and device rather than ignoring

* pretrained is not a bool

* add validation to make sure the user doesn't set both

* forbid config keys

* oops forgot eval

* address comments

* removed redundant check

* updated callsites not to use name

* fix

* validate extraneous keys in dataloader

* fix

* fix more

* fix III: revenge of the fix

* fix IV: a new hope

* fix V: the empire fixes back

* fixed some more types

* fix VI: return of the fix

* fix VII: the fix awakens

* fix VIII: the last bug

* fix

* final fix I think

* fixed

* fix style

* fix

* fix fix

* fix fix style

* icl task config

* fix train

* fix finetuning dataloader

* fix train types

* fix token counting

* fix train types

* oopsie

* fix straggler issues

* fix tests

* fix???

* fix hf v mpt gpu test and fmapi test

* pop device

* to_str_dict -> to_dict_recursive

* fix this darn unit test one more time

* fix ComposerMPTCausalLM constructor invocation

* Delete tests/models/hf/test_hf_fsdp.py

* unwrap model in unit tests

* model.model.model.model.model

* abstract away dataclass construction

* updated docstrings and removed dictconfig from logging logic

* flag icl tasks required or not

* updated a couple yamls

* updated train and eval scripts

* un-delete global train batch size

* fix

* I don't understand why this doesn't work

* that was the sneakiest bug I've ever fixed

* try to fix the regression test

* remove device train grad accum

* fix validate config

* removed unused import

* use variables

* missing mandatory value fix

* use correct type of error

* fix

* import TrainConfig just in case?

* moved trainconfig and evalconfig into utils

* works

* no cheating

* dicts everywhere gah

* try no recursive just

* rename typed helpers

* fix the test cases with deep magic

* towards a peaceful resolution

* remove comments

* fix type warnings

* Update llmfoundry/utils/config_utils.py

Co-authored-by: Daniel King <[email protected]>

* address low-hanging fruit

* remove peft wrapping extra model

* python 🤝 haskell

* dataset config should be dict

* just because omega starts with OMMMM does not mean it's zen

* fix

* fix

* structured settlement

* precision further down

* throws TypeError instead of MissingMandatoryValue or whatever

* remove debugging statement

* remove to_container calls everywhere

* wrap then unwrap

* pyright

* error early on missing mandatory values

* remove unnecessary ignore

* update unit tests

* update eval yamls

* Update train.py

* make log level optional again

* oopsie

* use keywords for arg clarity

* use keywords for arg clarity

* style

* style

* dist timeout

* resolve deeper conflict issues

* fix train.py

* fix registry

* fix dataloader

* fix train II

* fix dataloader and utils

* fix dictconfig

* skill issue

* add new keys

* remove pop_config

* fix

---------

Co-authored-by: Daniel King <[email protected]>
milocress and dakinggg authored May 8, 2024
1 parent a777014 commit cc8351c
Showing 69 changed files with 1,888 additions and 1,614 deletions.
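The commit messages above describe moving scripts/train/train.py and scripts/eval/eval.py from ad-hoc DictConfig access to dataclass-backed structured configs. As a rough sketch of that general OmegaConf pattern (the dataclass and its fields below are invented for illustration and are not the repo's actual TrainConfig):

# Rough sketch of the structured-config pattern (the dataclass and its fields
# are invented for illustration; the repo's real TrainConfig has many more).
from dataclasses import dataclass
from typing import Any, Dict

from omegaconf import MISSING
from omegaconf import OmegaConf as om


@dataclass
class ExampleTrainConfig:
    model: Dict[str, Any] = MISSING          # mandatory
    train_loader: Dict[str, Any] = MISSING   # mandatory
    max_duration: str = MISSING              # mandatory
    seed: int = 17                           # optional, has a default
    precision: str = 'amp_bf16'              # optional, has a default


yaml_cfg = om.create({
    'model': {'name': 'mpt_causal_lm', 'd_model': 128},
    'train_loader': {'name': 'text', 'dataset': {'max_seq_len': 128}},
    'max_duration': '10ba',
})

# Merging against the dataclass schema type-checks values and rejects unknown
# top-level keys; to_object fails fast if a mandatory field is still missing.
merged = om.merge(om.structured(ExampleTrainConfig), yaml_cfg)
train_cfg: ExampleTrainConfig = om.to_object(merged)  # type: ignore
print(train_cfg.precision)  # 'amp_bf16'

Merging the YAML against a dataclass schema is what provides type checking, rejection of unknown keys, and early failures on missing mandatory values, which several of the commits above are working toward.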
14 changes: 7 additions & 7 deletions llmfoundry/data/dataloader.py
@@ -3,10 +3,9 @@
 
 """Dataloader builder utilities."""
 
-from typing import Union
+from typing import Any, Dict
 
 from composer import DataSpec
-from omegaconf import DictConfig
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry import registry
@@ -18,9 +17,9 @@
 
 
 def build_dataloader(
-    cfg: DictConfig,
+    cfg: Dict[str, Any],
     tokenizer: PreTrainedTokenizerBase,
-    device_batch_size: Union[int, float],
+    device_batch_size: int,
 ) -> DataSpec:
     """Builds a dataloader from a config.
@@ -30,14 +29,15 @@ def build_dataloader(
         device_batch_size (int): The size of the batches (number of examples)
            that the dataloader will produce.
     """
-    kwargs = {
-        'cfg': cfg,
+    name = cfg.pop('name')
+    kwargs: Dict[str, Any] = {
+        **cfg,
         'tokenizer': tokenizer,
         'device_batch_size': device_batch_size,
     }
 
     return construct_from_registry(
-        name=cfg.name,
+        name=name,
         registry=registry.dataloaders,
         partial_function=False,
         pre_validation_function=None,
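The net effect of the dataloader.py change: build_dataloader now takes a plain dict, pops 'name' to select a builder from registry.dataloaders, and forwards the remaining keys as keyword arguments. A usage sketch, with illustrative config values that are not taken from the commit:

# Usage sketch for the new dict-based build_dataloader (config values are illustrative).
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.utils.builders import build_tokenizer

tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {'model_max_length': 2048})

cfg = {
    'name': 'text',  # popped and used to look up the builder in registry.dataloaders
    'dataset': {
        'remote': 's3://my-bucket/my-copy-c4',  # placeholder path
        'split': 'val_small',
        'shuffle': False,
        'max_seq_len': 2048,
    },
    'drop_last': False,
    'num_workers': 4,
}

# Everything except 'name' is forwarded to the registered builder as keyword arguments.
data_spec = build_dataloader(cfg, tokenizer, device_batch_size=2)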
349 changes: 214 additions & 135 deletions llmfoundry/data/finetuning/dataloader.py

Large diffs are not rendered by default.

63 changes: 32 additions & 31 deletions llmfoundry/data/packing.py
@@ -3,12 +3,11 @@
 
 import logging
 import tempfile
-from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
 from composer.utils import dist
-from omegaconf import DictConfig
 from transformers import PreTrainedTokenizerBase
 
 log = logging.getLogger(__name__)
@@ -318,7 +317,7 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int):
 
 
 def auto_packing_ratio(
-    dataloader_cfg: DictConfig,
+    dataloader_cfg: Dict[str, Any],
     tokenizer: PreTrainedTokenizerBase,
     device_batch_size: int,
     num_packing_ratios: int = 20,
@@ -352,20 +351,21 @@ def auto_packing_ratio(
     # Set the seed so that auto packing is deterministic.
     reproducibility.seed_all(0)
 
-    max_seq_len = dataloader_cfg.dataset.max_seq_len
     # If max_seq_len is very small, skip profiling and select packing ratio of 1.
+    dataset_config = dataloader_cfg['dataset']
+    max_seq_len = dataset_config.get('max_seq_len')
     if max_seq_len <= 100:
         return 1
 
     min_ratio = 1
     max_ratio = max_seq_len / 100
     profiling_results = profile_packing(
-        dataloader_cfg,
-        tokenizer,
-        min_ratio,
-        max_ratio,
-        num_packing_ratios,
-        device_batch_size,
+        dataloader_cfg=dataloader_cfg,
+        tokenizer=tokenizer,
+        min_ratio=min_ratio,
+        max_ratio=max_ratio,
+        num_packing_ratios=num_packing_ratios,
+        device_batch_size=device_batch_size,
     )
 
     # Obtain the maximum packing_ratio/minimum padding that has no waste.
@@ -392,7 +392,7 @@
 
 
 def profile_packing(
-    dataloader_cfg: DictConfig,
+    dataloader_cfg: Dict[str, Any],
     tokenizer: PreTrainedTokenizerBase,
     min_ratio: float,
     max_ratio: float,
@@ -416,39 +416,40 @@
 
     from llmfoundry.data.dataloader import build_dataloader
 
-    max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
-    max_leftovers_to_keep = dataloader_cfg.dataset.get(
-        'max_leftovers_to_keep',
-        None,
-    )
+    dataset_cfg = dataloader_cfg['dataset']
+    max_seq_len = dataset_cfg.get('max_seq_len')
+    max_leftovers_to_keep = dataset_cfg.get('max_leftovers_to_keep', None)
 
     # Turn off packing and sequence parallelism for the dataloader (we want raw, pre-packed, full-length examples)
     dataloader_cfg = copy.deepcopy(dataloader_cfg)
-    dataloader_cfg.dataset.packing_ratio = 1.0
-    dataloader_cfg.dataset.auto_packing_replication = dataloader_cfg.dataset.get(
-        'seq_parallel_replication',
-        1,
-    ) or 1
-    dataloader_cfg.dataset.seq_parallel_replication = 1
-    dataloader_cfg.drop_last = False
-    dataloader_cfg.num_workers = 0
-    dataloader_cfg.prefetch_factor = None
-    dataloader_cfg.persistent_workers = False
+    dataloader_cfg.update({
+        'drop_last': False,
+        'num_workers': 0,
+        'prefetch_factor': None,
+        'persistent_workers': False,
+    })
+    dataloader_cfg['dataset']['packing_ratio'] = 1.0
+    dataloader_cfg['dataset']['auto_packing_replication'
+                             ] = dataloader_cfg['dataset'].get(
+                                 'seq_parallel_replication',
+                                 1,
+                             ) or 1
+    dataloader_cfg['dataset']['seq_parallel_replication'] = 1
 
     # If streaming dataset, use a temporary local folder for profiling
     local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
-    if dataloader_cfg.dataset.get('remote') is not None:
+    if dataloader_cfg['dataset'].get('remote') is not None:
         tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
         gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
         tmp_path = gathered_paths[local_rank_zero]
-        dataloader_cfg.dataset.local = tmp_path
+        dataloader_cfg['dataset']['local'] = tmp_path
 
-    if dataloader_cfg.dataset.get('streams') is not None:
-        for stream_config in dataloader_cfg.dataset.streams.values():
+    if dataloader_cfg['dataset'].get('streams') is not None:
+        for stream_config in dataloader_cfg['dataset']['streams'].values():
             tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
             gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
             tmp_path = gathered_paths[local_rank_zero]
-            stream_config.local = tmp_path
+            stream_config['local'] = tmp_path
 
     # Determine the packing_ratio values we'll try
     packing_ratios, raw_batch_sizes = [], []
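auto_packing_ratio and profile_packing likewise now take the dataloader config as a plain dict and pass arguments by keyword. A sketch of invoking the profiler directly, with placeholder dataset fields that are not taken from the commit:

# Sketch: estimating a packing ratio with the new plain-dict config
# (dataset fields below are placeholders).
from llmfoundry.data.packing import auto_packing_ratio
from llmfoundry.utils.builders import build_tokenizer

tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {'model_max_length': 2048})

dataloader_cfg = {
    'name': 'finetuning',
    'dataset': {
        'hf_name': 'mosaicml/dolly_hhrlhf',
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 2,
}

# Profiles several candidate ratios and returns the largest one with no waste.
ratio = auto_packing_ratio(
    dataloader_cfg=dataloader_cfg,
    tokenizer=tokenizer,
    device_batch_size=8,
)
print(f'suggested packing_ratio: {ratio}')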
80 changes: 49 additions & 31 deletions llmfoundry/data/text_data.py
@@ -20,8 +20,6 @@
 import numpy as np
 import torch
 from composer.core.data_spec import DataSpec
-from omegaconf import DictConfig
-from omegaconf import OmegaConf as om
 from streaming import Stream, StreamingDataset
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerBase
@@ -268,79 +266,96 @@ def get_sequence_id_from_batch(
     return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
 
 
-def build_streams(dataset_cfg: DictConfig):
-    streams_dict = dataset_cfg.pop('streams', None)
+def build_streams(streams: Optional[Dict[str, Any]] = None,):
+    streams_dict = streams
     # build streams
-    streams = None
+    streams_ret = []
     if streams_dict is not None:
-        streams = []
-        for stream in streams_dict.values():
-            # stream is the streams kwargs
-            # fwd all kwargs with **stream allows streaming to check args
-            streams.append(Stream(**stream))
-    return streams
+        streams_ret = [Stream(**stream) for stream in streams_dict.values()]
+    return streams_ret
 
 
 def build_text_dataloader(
-    cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
-    device_batch_size: Union[int, float],
+    device_batch_size: int,
+    dataset: Dict[str, Any],
+    drop_last: bool,
+    num_workers: int,
+    pin_memory: bool = True,
+    prefetch_factor: int = 2,
+    persistent_workers: bool = True,
+    timeout: int = 0,
 ) -> DataSpec:
-    assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
+
+    dataset_cfg = dataset
 
     # get kwargs
-    cfg.dataset['replication'], dataset_batch_size = construct_from_registry(
+    dataset_cfg['replication'], dataset_batch_size = construct_from_registry(
         name='dataset_replication_validator',
         registry=registry.dataset_replication_validators,
         partial_function=False,
         kwargs={
-            'cfg': cfg,
+            'dataset_cfg': dataset_cfg,
             'tokenizer': tokenizer,
             'device_batch_size': device_batch_size,
         },
     )
 
-    streams = build_streams(cfg.dataset)
+    streams = build_streams(
+        streams=dataset_cfg.pop('streams')
+        if 'streams' in dataset_cfg else None,
+    )
 
     valid_streaming_text_dataset_parameters = inspect.signature(
         StreamingTextDataset,
     ).parameters
 
     dataset_config_subset_for_streaming_text_dataset = {
         k: v
-        for k, v in cfg.dataset.items()
+        for k, v in dataset_cfg.items()
         if k in valid_streaming_text_dataset_parameters
     }
 
     # build dataset potentially with streams
-    dataset = StreamingTextDataset(
+    text_dataset = StreamingTextDataset(
         tokenizer=tokenizer,
         streams=streams,
         batch_size=dataset_batch_size,
         **dataset_config_subset_for_streaming_text_dataset,
     )
 
+    dataloader_cfg = {
+        'name': 'text',
+        'dataset': dataset_cfg,
+        'drop_last': drop_last,
+        'num_workers': num_workers,
+        'pin_memory': pin_memory,
+        'prefetch_factor': prefetch_factor,
+        'persistent_workers': persistent_workers,
+        'timeout': timeout,
+    }
+
     collate_fn, dataloader_batch_size = construct_from_registry(
         name='text_collator',
         registry=registry.collators,
         partial_function=False,
         kwargs={
-            'cfg': cfg,
-            'tokenizer': dataset.tokenizer,
+            'dataloader_cfg': dataloader_cfg,
+            'tokenizer': tokenizer,
             'dataset_batch_size': dataset_batch_size,
         },
    )
 
     dl = DataLoader(
-        dataset,
+        text_dataset,
         collate_fn=collate_fn,
         batch_size=dataloader_batch_size,
-        drop_last=cfg.drop_last,
-        num_workers=cfg.num_workers,
-        pin_memory=cfg.get('pin_memory', True),
-        prefetch_factor=cfg.get('prefetch_factor', 2),
-        persistent_workers=cfg.get('persistent_workers', True),
-        timeout=cfg.get('timeout', 0),
+        drop_last=drop_last,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        prefetch_factor=prefetch_factor,
+        persistent_workers=persistent_workers,
+        timeout=timeout,
     )
 
     return construct_from_registry(
@@ -349,7 +364,7 @@ def build_text_dataloader(
         partial_function=False,
         kwargs={
             'dl': dl,
-            'dataset_cfg': cfg.dataset,
+            'dataset_cfg': dataset_cfg,
         },
     )
 
@@ -415,14 +430,17 @@ def build_text_dataloader(
         'drop_last': False,
         'num_workers': 4,
     }
-    cfg = om.create(cfg)
     device_batch_size = 2
 
     tokenizer_name = args.tokenizer
     tokenizer_kwargs = {'model_max_length': args.max_seq_len}
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
 
-    loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
+    loader = build_text_dataloader(
+        **cfg,
+        tokenizer=tokenizer,
+        device_batch_size=device_batch_size,
+    ).dataloader
     assert isinstance(loader, DataLoader)
     assert isinstance(loader.dataset, StreamingTextDataset)
     tokenizer = loader.dataset.tokenizer
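Two smaller points from the text_data.py diff: build_streams now receives only the streams sub-dict (and returns an empty list instead of None when there are no streams), and build_text_dataloader takes its former cfg fields as explicit keyword arguments, as the updated __main__ snippet shows. A minimal sketch of the new build_streams call, with placeholder bucket paths:

# Minimal sketch of the new build_streams signature (paths are placeholders).
from llmfoundry.data.text_data import build_streams

streams = build_streams(
    streams={
        'c4': {'remote': 's3://bucket/c4', 'local': '/tmp/c4', 'proportion': 0.7},
        'wiki': {'remote': 's3://bucket/wiki', 'local': '/tmp/wiki', 'proportion': 0.3},
    },
)
# Returns a list of streaming.Stream objects; with streams=None it returns [].
print(len(streams))  # 2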