Skip to content

Commit

Permalink
Merge branch 'main' into milo/unify-peft-codepath
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress authored Dec 5, 2024
2 parents 14d7209 + 16f92ef commit e2d8ecb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 10 deletions.
7 changes: 5 additions & 2 deletions llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,9 @@ def profile_packing(

# If streaming dataset, use a temporary local folder for profiling
local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
if dataset_cfg.get('remote') is not None:
if dataset_cfg.get(
'remote',
) is not None and dataset_cfg.get('local') is None:
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
Expand All @@ -485,7 +487,8 @@ def profile_packing(
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
stream_config['local'] = tmp_path
if stream_config.get('local') is None:
stream_config['local'] = tmp_path

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
'mlflow>=2.14.1,<2.19',
'accelerate>=0.25,<1.2', # for HF inference `device_map`
'transformers>=4.43.2,<4.47',
'mosaicml-streaming>=0.9.0,<0.10',
'mosaicml-streaming>=0.10.0,<0.11',
'torch>=2.5.1,<2.5.2',
'datasets>=2.20.0,<2.21',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
Expand Down
65 changes: 58 additions & 7 deletions tests/data/test_packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any
from typing import Any, Callable
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -161,27 +161,73 @@ def test_dist_auto_packing(profile_packing: Mock):
assert packing_ratio == 2


def get_remote_config(
base_cfg: dict,
remote_dir: str,
local_dir: str,
) -> DictConfig:
return DictConfig({
**base_cfg,
'dataset': {
**base_cfg['dataset'],
'remote': remote_dir,
'local': local_dir,
},
})


def get_streams_config(
base_cfg: dict,
remote_dir: str,
local_dir: str,
) -> DictConfig:
return DictConfig({
**base_cfg,
'dataset': {
**base_cfg['dataset'],
'streams': {
'stream_with_remote': {
'remote': remote_dir,
'local': local_dir,
},
'stream_without_remote': {
'local': remote_dir,
},
},
},
})


def patched_packing_ratio(*args: Any, **kwargs: Any):
from llmfoundry.data.packing import auto_packing_ratio

return auto_packing_ratio(*args, **kwargs, num_packing_ratios=4)


@pytest.mark.parametrize(
'get_config',
[
get_remote_config,
get_streams_config,
],
)
@patch(
'llmfoundry.data.finetuning.dataloader.auto_packing_ratio',
patched_packing_ratio,
)
def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
def test_auto_packing_with_streaming_dataloader(
get_config: Callable[[dict, str, str], DictConfig],
tmp_path: Path,
):
columns = {'prompt': 'str', 'response': 'str'}
tokenizer = build_tokenizer('gpt2', {})
remote_dir = str(tmp_path / 'remote')
local_dir = str(tmp_path / 'local')
with MDSWriter(out=remote_dir, columns=columns, compression=None) as out:
out.write({'prompt': 'HELLO', 'response': 'WORLD'})
cfg = DictConfig({

base_cfg = {
'dataset': {
'remote': remote_dir,
'local': local_dir,
'packing_ratio': 'auto',
'max_seq_len': 200,
'decoder_only_format': True,
Expand All @@ -194,7 +240,9 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
'prefetch_factor': None,
'persistent_workers': False,
'timeout': 0,
})
}

cfg = get_config(base_cfg, remote_dir, local_dir)

loader = build_finetuning_dataloader(
**cfg,
Expand All @@ -214,7 +262,10 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
assert isinstance(loader.batch_size, int)
assert loader.dataset.packing_ratio == int(loader.batch_size / 6)

state_dict = loader.dataset.state_dict(num_samples=2, from_beginning=False)
state_dict = loader.dataset.state_dict(
num_samples=2,
from_beginning=False,
)
assert state_dict['sample_in_epoch'] == 2 * loader.dataset.packing_ratio


Expand Down

0 comments on commit e2d8ecb

Please sign in to comment.