
Commit

Merge branch 'mosaicml:main' into main
ShashankMosaicML authored Jan 29, 2024
2 parents 624a339 + 34cdaf6 commit 1c25b98
Showing 31 changed files with 776 additions and 204 deletions.
2 changes: 1 addition & 1 deletion .ci/FILE_HEADER
@@ -1,2 +1,2 @@
Copyright 2022 MosaicML LLM Foundry authors
Copyright 2024 MosaicML LLM Foundry authors
SPDX-License-Identifier: Apache-2.0
2 changes: 1 addition & 1 deletion .github/workflows/docker.yaml
@@ -7,7 +7,7 @@ on:
branches:
- main
paths:
- ./Dockerfile
- Dockerfile
- .github/workflows/docker.yaml
workflow_dispatch: {}
jobs:
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -57,14 +57,15 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.3.1
rev: v1.5.4
hooks:
- id: insert-license
args:
- --license-filepath
- .ci/FILE_HEADER
- --comment-style
- '#'
- --allow-past-years
types: [python]
- repo: https://github.com/PyCQA/docformatter
rev: v1.5.0
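For reference, the insert-license hook prepends the contents of .ci/FILE_HEADER to each Python file using the '#' comment style, so a file touched by the hook starts with a header like the sketch below. The new --allow-past-years flag is meant to let existing headers keep an earlier year (for example a pre-2024 copyright line) instead of being flagged; the exact behavior is defined by the Lucas-C/pre-commit-hooks release pinned above.

    # Copyright 2024 MosaicML LLM Foundry authors
    # SPDX-License-Identifier: Apache-2.0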
78 changes: 59 additions & 19 deletions llmfoundry/callbacks/async_eval_callback.py
@@ -19,7 +19,7 @@
from composer.utils import dist
from composer.utils.misc import create_interval_scheduler

from mcli import ComputeConfig, Run, RunConfig, create_run, get_run
from mcli import Run, RunConfig, create_run, get_run

log = logging.getLogger(__name__)

@@ -33,7 +33,9 @@
OPTIONAL_PARAMS_FOR_EVAL = {
'dist_timeout',
'eval_gauntlet',
'eval_loader',
'fsdp_config',
'eval_subset_num_batches',
'icl_subset_num_batches',
'loggers',
'precision',
@@ -175,50 +177,84 @@ def validate_interval(interval: Union[str, int, Time],
return async_interval


def validate_eval_run_config(
eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:

if not eval_run_config:
return {}

run_config = eval_run_config.copy()

supported_keys = {'image', 'command', 'compute', 'scheduling'}
found_unsupported = set()
for key in run_config:
if key not in supported_keys:
found_unsupported.add(key)

if found_unsupported:
raise ValueError(
f'Unsupported eval run config keys found: {", ".join(found_unsupported)}'
+ f'. Supported keys: {supported_keys}')

return run_config


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.
This callback is currently experimental. The API may change in the future.
Args:
training_config: Dict[str, Any]: The config from the training run
training_params: Dict[str, Any]: The parameter config from the training run
interval: Union[str, int, Time]: The interval describing how often eval runs should be
launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to
use for the eval run. If not provided, the same cluster as the current run and a
single, full GPU node will be used.
eval_run_config: Optional[Dict[str, Any]]: A subset of mcli run config values to use
for the eval run. If not specified, any fields from run config will be created
dynamically from the training run config and the interval. The following fields
are supported:
- ``image``: Image of the eval run. Default: same as training run
- ``command``: Command to run for the eval run. Default: calls
`composer scripts/eval/eval.py $PARAMETERS`. If custom setup is needed,
the command should include calling the eval script with $PARAMETERS
- ``compute``: Compute to use for the eval run. Default: same cluster as
the training run and a single node (8 GPUs)
- ``scheduling``: Scheduling to use for the eval run. Default: same as training run
All fields are optional, but if specified, must be valid for a mcli run config. We
provide this optional config to give you the most flexibility in customizing the eval
run, but it is recommended to use the default values unless you have a specific use case
"""

def __init__(
self,
training_config: Dict[str, Any],
training_params: Dict[str, Any],
interval: Union[str, int, Time],
compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
eval_run_config: Optional[Dict[str, Any]] = None,
):

for required in ('save_interval', 'save_folder'):
if required not in training_config:
if required not in training_params:
raise ValueError(f'{required} required for async eval')

self.checkpoint_save_folder = training_config['save_folder']
self.training_config = training_config
self.checkpoint_save_folder = training_params['save_folder']
self.training_params = training_params
self.eval_run_config = validate_eval_run_config(eval_run_config)
self.interval = validate_interval(interval,
self.training_config['save_interval'])
self.training_params['save_interval'])
self.check_interval = create_interval_scheduler(
interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)
self.compute = compute
self.last_checkpoint: Optional[str] = None

# Run these during init to fail fast in any of the error cases
self.current_run = self._get_current_run()
get_eval_parameters(
parameters=training_config,
parameters=training_params,
checkpoint='test',
training_run_name=self.current_run.name,
)
@@ -259,7 +295,7 @@ def close(self, state: State, logger: Logger) -> None:
if dist.get_global_rank() != 0:
return

save_latest_filename = self.training_config.get('save_latest_filename',
save_latest_filename = self.training_params.get('save_latest_filename',
None)

if not save_latest_filename:
@@ -297,7 +333,7 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
run_name = get_run_name(self.current_run.name, str(current_interval))

params = get_eval_parameters(
parameters=self.training_config,
parameters=self.training_params,
checkpoint=checkpoint,
training_run_name=self.current_run.name,
)
@@ -347,12 +383,16 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
# TODO: This just runs an eval run, but we also want to attach the
# deployment, which would require a hf conversion and parametrizing the
# dependent_deployment in the run config
command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS'
default_command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS'
run_config = RunConfig(
name=run_name,
image=self.current_run.image,
compute=self.compute or default_compute,
command=command,
image=self.eval_run_config.get('image', self.current_run.image),
command=self.eval_run_config.get('command', default_command),
compute=self.eval_run_config.get('compute', default_compute),
scheduling=self.eval_run_config.get(
'scheduling',
self.current_run.submitted_config.scheduling,
),
integrations=integrations,
env_variables=cfg.env_variables,
metadata=cfg.metadata,
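As a quick illustration of the new eval_run_config argument, the sketch below exercises the key-validation step on its own. The cluster name and priority values are placeholders, not repository defaults, and the commented AsyncEval wiring only works inside a MosaicML platform training run with a complete training config.

    from llmfoundry.callbacks.async_eval_callback import validate_eval_run_config

    # Only 'image', 'command', 'compute', and 'scheduling' are accepted;
    # any other key raises a ValueError naming the unsupported entries.
    eval_run_config = validate_eval_run_config({
        'compute': {'gpus': 8, 'cluster': 'r1z1'},  # hypothetical cluster name
        'scheduling': {'priority': 'low'},          # hypothetical scheduling override
    })

    # Inside a platform run this would be passed to the callback, e.g.:
    #   AsyncEval(training_params=train_cfg, interval='1000ba',
    #             eval_run_config=eval_run_config)
    # where train_cfg is the full training config and must include
    # 'save_folder' and 'save_interval'.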
27 changes: 27 additions & 0 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -6,6 +6,7 @@
import logging
import math
import os
import re
import tempfile
from pathlib import Path
from typing import Optional, Sequence, Union
@@ -27,6 +28,23 @@

log = logging.getLogger(__name__)

_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)


def _maybe_get_license_filename(local_dir: str) -> Optional[str]:
"""Returns the name of the license file if it exists in the local_dir.
Note: This is intended to be consistent with the code in MLflow.
https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
If the license file does not exist, returns None.
"""
try:
return next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))
except StopIteration:
return None


class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.
@@ -279,6 +297,15 @@ def _save_checkpoint(self, state: State, logger: Logger):
path=local_save_path,
**self.mlflow_logging_config,
)

license_filename = _maybe_get_license_filename(
local_save_path)
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
os.path.join(local_save_path, license_filename),
)

mlflow_logger.register_model(
model_uri=local_save_path,
name=self.mlflow_registered_model_name,
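A small sanity check of the new license-detection helper, using a scratch directory with illustrative file names:

    import os
    import tempfile

    from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename

    with tempfile.TemporaryDirectory() as local_dir:
        # Simulate a saved Hugging Face checkpoint directory.
        for name in ('config.json', 'model.safetensors', 'LICENSE.txt'):
            open(os.path.join(local_dir, name), 'w').close()

        # The case-insensitive pattern matches 'LICENSE', 'license.md', etc.
        print(_maybe_get_license_filename(local_dir))  # -> LICENSE.txt

When a license file is found, _save_checkpoint logs it as an artifact of the MLflow run before the model is registered; if none is found the helper returns None and the extra logging step is skipped.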
9 changes: 8 additions & 1 deletion llmfoundry/callbacks/scheduled_gc_callback.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import gc
from typing import Optional

import torch
from composer.core import Callback, State
@@ -19,16 +20,19 @@ class ScheduledGarbageCollector(Callback):
"""Disable automatic garbage collection and collect garbage at interval.
Args:
batch_interval (int): Number of batches between checkpoints call to gc.collect()
batch_interval (int): Number of batches between calls to gc.collect()
gen_1_batch_interval(int, optional): Number of batches between calls to gc.collect(1)
eval_keep_disabled (bool): keep gc disabled during eval (default: False)
"""

def __init__(
self,
batch_interval: int,
gen_1_batch_interval: Optional[int] = None,
eval_keep_disabled: bool = False,
):
self.batch_interval = batch_interval
self.gen_1_batch_interval = gen_1_batch_interval
self.eval_keep_disabled = eval_keep_disabled
self.gc_init_state = None

@@ -56,6 +60,9 @@ def fit_end(self, state: State, logger: Logger) -> None:
def before_dataloader(self, state: State, logger: Logger) -> None:
del logger # unused

if self.gen_1_batch_interval is not None and state.timestamp.batch.value % self.gen_1_batch_interval == 0:
gc.collect(1)

if state.timestamp.batch.value % self.batch_interval == 0:
gc_cuda()

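A minimal usage sketch for the new gen_1_batch_interval option; the interval values are placeholders. Per before_dataloader above, this configuration triggers gc.collect(1) every 2,000 batches and the full gc_cuda() collection every 10,000 batches, while automatic garbage collection otherwise stays disabled during training.

    from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector

    # Passed to composer.Trainer via callbacks=[scheduled_gc].
    scheduled_gc = ScheduledGarbageCollector(
        batch_interval=10_000,       # full gc_cuda() every 10k batches
        gen_1_batch_interval=2_000,  # generation-1 gc.collect(1) every 2k batches
        eval_keep_disabled=False,    # re-enable automatic GC during eval
    )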