Skip to content

Commit

Permalink
Merge branch 'main' into rm_compile_glu
Browse files Browse the repository at this point in the history
  • Loading branch information
josejg authored Apr 9, 2024
2 parents 1bae81b + 2939cc9 commit c9270dc
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 1 deletion.
2 changes: 2 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def build_finetuning_dataloader(cfg: DictConfig,
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
max_seq_len=cfg.dataset.max_seq_len,
allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False),
replication=cfg.dataset.get('replication', None),
)

else:
Expand Down
10 changes: 10 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,12 @@ class StreamingFinetuningDataset(StreamingDataset):
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If ``True``, keep going when a shard contains Pickle
    (which allows arbitrary code execution during deserialization); if ``False``,
    raise an error instead. Defaults to ``False``.
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
"""

def __init__(self,
Expand All @@ -516,6 +522,8 @@ def __init__(self,
sampling_granularity: int = 1,
batching_method: str = 'random',
max_seq_len: int = 2048,
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
**kwargs: Any):

if len(kwargs) > 0:
Expand Down Expand Up @@ -552,6 +560,8 @@ def __init__(self,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
)

self.tokenizer = tokenizer
Expand Down
10 changes: 10 additions & 0 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ class StreamingTextDataset(StreamingDataset):
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If ``True``, keep going when a shard contains Pickle
    (which allows arbitrary code execution during deserialization); if ``False``,
    raise an error instead. Defaults to ``False``.
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
"""

def __init__(self,
Expand All @@ -109,6 +115,8 @@ def __init__(self,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
**kwargs: Any):

if len(kwargs) > 0:
Expand Down Expand Up @@ -151,6 +159,8 @@ def __init__(self,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
Expand Down
1 change: 1 addition & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def build_eval_loaders(
# Load the eval data to fail fast. metrics will get added
# later in add_metrics_to_eval_loaders, after the model is loaded
metric_names=[],
device_eval_microbatch_size=device_eval_batch_size,
)
evaluators.append(eval_loader)
return evaluators
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
'mlflow>=2.10,<3',
'accelerate>=0.25,<0.26', # for HF inference `device_map`
'transformers>=4.39.3,<4.40',
'mosaicml-streaming>=0.7.4,<0.8',
'mosaicml-streaming>=0.7.5,<0.8',
'torch>=2.2.1,<2.3',
'datasets>=2.16,<2.17',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
Expand Down

0 comments on commit c9270dc

Please sign in to comment.