Merge branch 'main' into daniels-exceptions
dakinggg authored May 24, 2024
2 parents 3bc11f7 + 1e4bd37 commit a65e869
Showing 6 changed files with 69 additions and 17 deletions.
29 changes: 23 additions & 6 deletions llmfoundry/models/layers/ffn.py
@@ -382,10 +382,30 @@ def attach_ffn_mb_args(
    ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group


def get_fsdp_submesh_2d(device_mesh: DeviceMesh):
    """Get the submesh for FSDP.

    Args:
        device_mesh (DeviceMesh): The full device mesh.

    Returns:
        DeviceMesh: The submesh for FSDP.
    """
    if device_mesh.mesh.ndim == 2:
        submesh = device_mesh['weight_parallel']
    elif device_mesh.mesh.ndim == 3:
        raise RuntimeError('HSDP + MoE is not supported.')
    else:
        raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.')

    return submesh


def set_ffn_device_mesh(
    ffn: nn.Module,
    moe_world_size: int,
    device_mesh: DeviceMesh,
    get_fsdp_submesh: Callable[[DeviceMesh], DeviceMesh],
):
    """Sets the device mesh in FSDP kwargs.
@@ -413,12 +433,7 @@ def set_ffn_device_mesh(
for name, dtensorified_param in dtensorified_params:
ffn.experts.mlp.register_parameter(name, dtensorified_param)

-    if device_mesh.mesh.ndim == 2:
-        submesh = device_mesh['weight_parallel']
-    elif device_mesh.mesh.ndim == 3:
-        raise RuntimeError(f'HSDP + MoE is not supported.')
-    else:
-        raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.')
+    submesh = get_fsdp_submesh(device_mesh)

ffn.experts._fsdp_kwargs_dict = {
'device_mesh': submesh,
@@ -470,6 +485,7 @@ def build_mb_moe(
ffn=ffn,
moe_world_size=moe_world_size,
device_mesh=kwargs['device_mesh'],
get_fsdp_submesh=get_fsdp_submesh_2d,
)

return ffn
@@ -536,6 +552,7 @@ def build_mb_dmoe(
ffn=ffn,
moe_world_size=moe_world_size,
device_mesh=kwargs['device_mesh'],
get_fsdp_submesh=get_fsdp_submesh_2d,
)

return ffn
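The net effect in this file: the mesh-dimension dispatch that previously lived inline in set_ffn_device_mesh is extracted into get_fsdp_submesh_2d, and both MegaBlocks builders now inject it as a Callable[[DeviceMesh], DeviceMesh], so alternative submesh policies can be supplied without touching the FFN wiring. A minimal sketch of the dispatch contract, using a hypothetical _Mesh stand-in because constructing a real DeviceMesh requires an initialized distributed process group:

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class _Mesh:
    """Hypothetical stand-in exposing the two DeviceMesh attributes used above."""
    ndim: int
    submeshes: Dict[str, '_Mesh'] = field(default_factory=dict)

    @property
    def mesh(self) -> '_Mesh':
        return self  # a real DeviceMesh exposes .mesh with an .ndim

    def __getitem__(self, name: str) -> '_Mesh':
        return self.submeshes[name]

def get_fsdp_submesh_2d(device_mesh):
    # Mirrors the new helper: a 2-D mesh maps to its weight-parallel slice.
    if device_mesh.mesh.ndim == 2:
        return device_mesh['weight_parallel']
    elif device_mesh.mesh.ndim == 3:
        raise RuntimeError('HSDP + MoE is not supported.')
    raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.')

mesh_2d = _Mesh(ndim=2, submeshes={'weight_parallel': _Mesh(ndim=1)})
assert get_fsdp_submesh_2d(mesh_2d).ndim == 1

Injecting the callable also keeps set_ffn_device_mesh agnostic to future mesh layouts (say, an HSDP-aware selector) without another signature change.
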
7 changes: 5 additions & 2 deletions llmfoundry/utils/config_utils.py
@@ -39,6 +39,7 @@
'update_batch_size_info',
'process_init_device',
'log_config',
'log_dataset_uri',
]


@@ -508,7 +509,6 @@ def log_config(cfg: Dict[str, Any]) -> None:

if 'mlflow' in loggers and mlflow.active_run():
mlflow.log_params(params=cfg)
-        _log_dataset_uri(cfg)


def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]:
@@ -619,12 +619,15 @@ def _process_data_source(
log.warning('DataSource Not Found.')


-def _log_dataset_uri(cfg: Dict[str, Any]) -> None:
+def log_dataset_uri(cfg: Dict[str, Any]) -> None:
"""Logs dataset tracking information to MLflow.
Args:
cfg (DictConfig): A config dictionary of a run
"""
loggers = cfg.get('loggers', None) or {}
if 'mlflow' not in loggers or not mlflow.active_run():
return
# Figure out which data source to use
data_paths = _parse_source_dataset(cfg)

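Promoting _log_dataset_uri to the public log_dataset_uri and adding the loggers/active-run guard means callers can invoke it unconditionally; it degrades to a no-op when MLflow is not in play. A quick sketch of the guard behavior, assuming llm-foundry at this commit is installed:

from llmfoundry.utils.config_utils import log_dataset_uri

# No 'mlflow' logger configured: the new guard returns before any dataset
# parsing or MLflow calls happen.
log_dataset_uri({'loggers': {}})
log_dataset_uri({})  # cfg.get('loggers', None) or {} also tolerates a missing key
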
29 changes: 26 additions & 3 deletions llmfoundry/utils/data_prep_utils.py
@@ -7,6 +7,8 @@
from typing import List, Optional

from composer.utils import ObjectStore
from composer.utils.object_store import ObjectStoreTransientError
from composer.utils.retrying import retry

__all__ = [
'merge_shard_groups',
@@ -78,6 +80,26 @@ def merge_shard_groups(root: str) -> None:
out.write(text)


@retry(ObjectStoreTransientError, num_attempts=5)
def download_file(
    object_store: ObjectStore,
    object_name: str,
    output_filename: str,
) -> None:
    """Downloads a file from an object store.

    Args:
        object_store (ObjectStore): Object store to download from
        object_name (str): Name of object to download
        output_filename (str): Local filename to write to
    """
    object_store.download_object(
        object_name=object_name,
        filename=output_filename,
        overwrite=True,
    )


class DownloadingIterable:

def __init__(
@@ -110,10 +132,11 @@ def __iter__(self):
self.output_folder,
object_name.strip('/'),
)
-            self.object_store.download_object(
+
+            download_file(
+                object_store=self.object_store,
                object_name=object_name,
-                filename=output_filename,
-                overwrite=True,
+                output_filename=output_filename,
)

with open(output_filename) as _txt_file:
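download_file wraps ObjectStore.download_object in Composer's retry decorator, so transient object-store failures are retried (five attempts, with backoff) instead of aborting a long conversion run. A small sketch of the retry semantics with a simulated flaky operation; flaky_op and the attempt counter are illustrative only:

from composer.utils.object_store import ObjectStoreTransientError
from composer.utils.retrying import retry

attempts = {'n': 0}

@retry(ObjectStoreTransientError, num_attempts=5)
def flaky_op() -> str:
    # Raises a transient error twice, then succeeds; retry absorbs the failures.
    attempts['n'] += 1
    if attempts['n'] < 3:
        raise ObjectStoreTransientError('simulated blip')
    return 'ok'

assert flaky_op() == 'ok'
assert attempts['n'] == 3

Non-transient exceptions (for example, a missing object) still propagate immediately, which is what DownloadingIterable and the DONE-file check below rely on.
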
11 changes: 8 additions & 3 deletions scripts/data_prep/convert_text_to_mds.py
@@ -24,6 +24,7 @@
from llmfoundry.utils import maybe_create_mosaicml_logger
from llmfoundry.utils.data_prep_utils import (
DownloadingIterable,
download_file,
merge_shard_groups,
)
from llmfoundry.utils.exceptions import (
@@ -329,9 +330,13 @@ def is_already_processed(
try:
with tempfile.TemporaryDirectory() as tmp_dir:
done_file = os.path.join(tmp_dir, DONE_FILENAME)
-            output_object_store.download_object(
-                os.path.join(output_folder_prefix, DONE_FILENAME),
-                done_file,
+            download_file(
+                object_store=output_object_store,
+                object_name=os.path.join(
+                    output_folder_prefix,
+                    DONE_FILENAME,
+                ),
+                output_filename=done_file,
)
with open(done_file) as df:
done_file_contents = df.read().splitlines()
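is_already_processed now fetches the DONE marker through the retrying download_file helper, so a transient read failure gets retried rather than immediately landing in the surrounding except block and misreporting a finished run as unprocessed. A hedged usage sketch of the same helper against a standalone object store; the s3://my-bucket URI and object name are placeholders, and credentials plus the object itself must exist for this to run:

import os
import tempfile

from composer.utils import maybe_create_object_store_from_uri

from llmfoundry.utils.data_prep_utils import download_file

# Placeholder bucket/prefix: swap in a URI your credentials can reach.
object_store = maybe_create_object_store_from_uri('s3://my-bucket/mds-output')

with tempfile.TemporaryDirectory() as tmp_dir:
    local_path = os.path.join(tmp_dir, 'done_marker')
    # Retries up to 5 times on ObjectStoreTransientError, then re-raises.
    download_file(
        object_store=object_store,
        object_name='mds-output/.done_marker',  # hypothetical object name
        output_filename=local_path,
    )
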
2 changes: 2 additions & 0 deletions scripts/train/train.py
@@ -55,6 +55,7 @@
TRAIN_CONFIG_KEYS,
TrainConfig,
log_config,
log_dataset_uri,
make_dataclass_and_log_config,
pop_config,
process_init_device,
@@ -530,6 +531,7 @@ def main(cfg: DictConfig) -> Trainer:
if train_cfg.log_config:
log.info('Logging config')
log_config(logged_cfg)
log_dataset_uri(logged_cfg)
torch.cuda.empty_cache()
gc.collect()

8 changes: 5 additions & 3 deletions tests/utils/test_mlflow_logging.py
@@ -7,8 +7,8 @@
import pytest

from llmfoundry.utils.config_utils import (
-    _log_dataset_uri,
    _parse_source_dataset,
+    log_dataset_uri,
)

mlflow = pytest.importorskip('mlflow')
@@ -84,10 +84,12 @@ def test_log_dataset_uri():
}},
source_dataset_train='huggingface/train_dataset',
source_dataset_eval='huggingface/eval_dataset',
loggers={'mlflow': {}},
)

-    with patch('mlflow.log_input') as mock_log_input:
-        _log_dataset_uri(cfg)
+    with patch('mlflow.log_input') as mock_log_input, \
+            patch('mlflow.active_run', return_value=True):
+        log_dataset_uri(cfg)
assert mock_log_input.call_count == 2
meta_dataset_calls = [
args[0] for args, _ in mock_log_input.call_args_list
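The test changes follow directly from the new guard in log_dataset_uri: without an mlflow entry in loggers and a truthy mlflow.active_run(), the helper returns before mlflow.log_input is ever reached, so the test patches active_run alongside log_input. The same pattern in isolation; the hf_name config shape is an assumption about what _parse_source_dataset reads:

from unittest.mock import patch

from llmfoundry.utils.config_utils import log_dataset_uri

cfg = {
    'train_loader': {'dataset': {'hf_name': 'huggingface/train_dataset'}},
    'loggers': {'mlflow': {}},
}

# Without the active_run patch, the guard would short-circuit the call.
with patch('mlflow.log_input') as mock_log_input, \
        patch('mlflow.active_run', return_value=True):
    log_dataset_uri(cfg)

assert mock_log_input.called
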
