Skip to content

Commit

Permalink
Add line splitting and other linting
Browse files (browse the repository at this point in the history)
  • Loading branch information
b-chu committed May 2, 2024
1 parent 5f39606 commit b1463a1
Show file tree
Hide file tree
Showing 155 changed files with 11,509 additions and 6,924 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,12 @@
default_language_version:
python: python3
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.2.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
Expand Down
31 changes: 25 additions & 6 deletions llmfoundry/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,19 +13,38 @@

# Filter out Hugging Face warning for not using a pinned revision of the model
hf_dynamic_modules_logger = logging.getLogger(
'transformers.dynamic_module_utils')
'transformers.dynamic_module_utils',
)
new_files_warning_filter = SpecificWarningFilter(
'A new version of the following files was downloaded from')
'A new version of the following files was downloaded from',
)

hf_dynamic_modules_logger.addFilter(new_files_warning_filter)

from llmfoundry import (algorithms, callbacks, cli, data, eval, interfaces,
loggers, metrics, models, optim, tokenizers, utils)
from llmfoundry import (
algorithms,
callbacks,
cli,
data,
eval,
interfaces,
loggers,
metrics,
models,
optim,
tokenizers,
utils,
)
from llmfoundry.data import StreamingFinetuningDataset, StreamingTextDataset
from llmfoundry.eval import InContextLearningDataset, InContextLearningMetric
from llmfoundry.models.hf import ComposerHFCausalLM
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel, MPTPreTrainedModel)
from llmfoundry.models.mpt import (
ComposerMPTCausalLM,
MPTConfig,
MPTForCausalLM,
MPTModel,
MPTPreTrainedModel,
)
from llmfoundry.optim import DecoupledLionW

__all__ = [
Expand Down
8 changes: 6 additions & 2 deletions llmfoundry/algorithms/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,12 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from composer.algorithms import (Alibi, GatedLinearUnits, GradientClipping,
LowPrecisionLayerNorm)
from composer.algorithms import (
Alibi,
GatedLinearUnits,
GradientClipping,
LowPrecisionLayerNorm,
)

from llmfoundry.registry import algorithms

Expand Down
22 changes: 16 additions & 6 deletions llmfoundry/callbacks/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,18 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from composer.callbacks import (EarlyStopper, EvalOutputLogging, Generate,
LRMonitor, MemoryMonitor, MemorySnapshot,
OOMObserver, OptimizerMonitor, RuntimeEstimator,
SpeedMonitor)
from composer.callbacks import (
EarlyStopper,
EvalOutputLogging,
Generate,
LRMonitor,
MemoryMonitor,
MemorySnapshot,
OOMObserver,
OptimizerMonitor,
RuntimeEstimator,
SpeedMonitor,
)

from llmfoundry.callbacks.async_eval_callback import AsyncEval
from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning
Expand All @@ -15,8 +23,10 @@
MegaBlocksMoE_TokPerExpert
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
LayerFreezing)
from llmfoundry.callbacks.resumption_callbacks import (
GlobalLRScaling,
LayerFreezing,
)
from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
from llmfoundry.registry import callbacks, callbacks_with_config

Expand Down
106 changes: 67 additions & 39 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,8 +16,10 @@
from composer.callbacks import CheckpointSaver
from composer.core import Event, State, Time, Timestamp, TimeUnit
from composer.loggers import Logger
from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR,
RUN_NAME_ENV_VAR)
from composer.loggers.mosaicml_logger import (
MOSAICML_PLATFORM_ENV_VAR,
RUN_NAME_ENV_VAR,
)
from composer.utils import dist
from composer.utils.file_helpers import list_remote_objects
from composer.utils.misc import create_interval_scheduler
Expand Down Expand Up @@ -65,15 +67,17 @@ def get_run_name(training_run_name: str, current_interval: str) -> str:
"""
name_without_uuid_suffix = training_run_name.rsplit('-', 1)[0]

max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len(
current_interval) - 2
max_length = MAX_RUN_NAME_BASE_LENGTH - len(
RUN_NAME_PREFIX,
) - len(current_interval) - 2

# A run name that is too long will fail a createRun call
if len(name_without_uuid_suffix) > max_length:
new_name = name_without_uuid_suffix[:max_length]
log.warning(
f'Training run name {name_without_uuid_suffix} may be too long,' +
f' truncating to {new_name}')
f' truncating to {new_name}',
)
name_without_uuid_suffix = new_name

return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'
Expand Down Expand Up @@ -107,7 +111,7 @@ def get_eval_parameters(

if looking_for:
raise Exception(
f'Missing the following required parameters for async eval: {looking_for}'
f'Missing the following required parameters for async eval: {looking_for}',
)

for logger, config in subset_keys.get('loggers', {}).items():
Expand All @@ -126,7 +130,7 @@ def get_eval_parameters(
new_models = {
'model_name': model_name,
'model': model,
'load_path': checkpoint
'load_path': checkpoint,
}

tokenizer = subset_keys.pop('tokenizer', None)
Expand All @@ -136,27 +140,32 @@ def get_eval_parameters(
return subset_keys


def validate_interval(interval: Union[str, int, Time],
save_interval: Union[str, int, Time]) -> Time:
def validate_interval(
interval: Union[str, int, Time],
save_interval: Union[str, int, Time],
) -> Time:

new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH)
async_interval = Time.from_input(interval, TimeUnit.EPOCH)

if new_save_interval.unit != async_interval.unit:
raise ValueError(
'Save interval and async eval interval must be in the same unit')
'Save interval and async eval interval must be in the same unit',
)
if async_interval < new_save_interval:
raise ValueError(
'Async eval interval must be equal or greater (less frequent) than save interval'
'Async eval interval must be equal or greater (less frequent) than save interval',
)
if async_interval.value % new_save_interval.value != 0:
raise ValueError(
'Async eval interval must be a multiple of save interval')
'Async eval interval must be a multiple of save interval',
)
return async_interval


def validate_eval_run_config(
eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
eval_run_config: Optional[Dict[str, Any]],
) -> Dict[str, Any]:

if not eval_run_config:
return {}
Expand All @@ -172,7 +181,8 @@ def validate_eval_run_config(
if found_unsupported:
raise ValueError(
f'Unsupported eval run config keys found: {", ".join(found_unsupported)}'
+ f'. Supported keys: {supported_keys}')
+ f'. Supported keys: {supported_keys}',
)

return run_config

Expand Down Expand Up @@ -222,7 +232,7 @@ def __init__(

if '/' in train_config.get('save_filename', ''):
raise ValueError(
'AsyncEval not supported for save_filename that includes a path'
'AsyncEval not supported for save_filename that includes a path',
)

self.checkpoint_save_folder = train_config['save_folder']
Expand All @@ -237,14 +247,18 @@ def __init__(
)

# Validate the interval (how often to launch eval runs)
self.interval = validate_interval(interval,
self.training_params['save_interval'])
self.interval = validate_interval(
interval,
self.training_params['save_interval'],
)

# Configures how often to check for new checkpoints. This is semi-arbitrary;
# really we just want to check often enough to pull relevant checkpoints
# but not so often that we're constantly checking
check_interval_value = max(self.interval.value // CHECKS_PER_INTERVAL,
1)
check_interval_value = max(
self.interval.value // CHECKS_PER_INTERVAL,
1,
)
self.check_interval = Time(check_interval_value, self.interval.unit)

# Keep track of checkpoints that have already been evaled
Expand All @@ -260,8 +274,10 @@ def __init__(
include_end_of_training=False,
)

log.info('Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {self.check_interval}')
log.info(
'Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {self.check_interval}',
)

def state_dict(self) -> Dict[str, Any]:
checkpoints_evaled = []
Expand All @@ -284,7 +300,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
self.checkpoints_evaled[eval_ts] = (checkpoint, run_name)

log.info(
f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}'
f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}',
)

@staticmethod
Expand Down Expand Up @@ -318,12 +334,12 @@ def _get_ready_sharded_checkpoints(

# expecting one shard per gpu + 1 for metadata
expected_shard_count = dist.get_world_size() + 1
if remote_file_group_counts[
checkpoint_ts_path] != expected_shard_count:
if remote_file_group_counts[checkpoint_ts_path
] != expected_shard_count:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded (missing shards '
+
f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping'
f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping',
)
continue

Expand Down Expand Up @@ -357,7 +373,8 @@ def _get_ready_single_checkpoints(

if checkpoint not in unique_remote_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded, skipping')
f'Checkpoint {checkpoint} not fully uploaded, skipping',
)
continue

checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
Expand All @@ -379,20 +396,23 @@ def _get_checkpoints_and_launch_runs(self, state: State):
checkpointer = callback
else:
log.warning(
'Multiple checkpoint savers found. Using the first one')
'Multiple checkpoint savers found. Using the first one',
)

if not checkpointer:
warnings.warn('No checkpoint saver callback found. Skipping eval')
return

if not checkpointer.all_saved_checkpoints_to_timestamp:
log.debug(
'No saved checkpoints found on the checkpointer. Skipping eval')
'No saved checkpoints found on the checkpointer. Skipping eval',
)
return

log.debug(
f'Found {len(checkpointer.all_saved_checkpoints_to_timestamp)} ' +
f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}')
f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}',
)

remote_checkpoints = list_remote_objects(self.checkpoint_save_folder)

Expand All @@ -403,19 +423,22 @@ def _get_checkpoints_and_launch_runs(self, state: State):
if state.fsdp_sharded_state_dict_enabled:
checkpoints_to_eval = self._get_ready_sharded_checkpoints(
checkpointer.all_saved_checkpoints_to_timestamp,
remote_checkpoints)
remote_checkpoints,
)
else:
checkpoints_to_eval = self._get_ready_single_checkpoints(
checkpointer.all_saved_checkpoints_to_timestamp,
remote_checkpoints)
remote_checkpoints,
)

for checkpoint_interval_path, checkpoint_timestamp in checkpoints_to_eval.items(
):
checkpoint_ts = checkpoint_timestamp.get(self.interval.unit)
if checkpoint_ts.value % self.interval.value != 0:
log.debug(
f'Checkpoint {checkpoint_interval_path} ({checkpoint_ts}) is '
+ f'not at an eval interval ({self.interval}), skipping')
+ f'not at an eval interval ({self.interval}), skipping',
)
continue
if checkpoint_ts in self.checkpoints_evaled:
continue # Skip checkpoints that have already been evaled
Expand Down Expand Up @@ -454,7 +477,9 @@ def close(self, state: State, logger: Logger) -> None:
latest_timestamp = state.timestamp.get(self.interval.unit)
if latest_timestamp not in self.checkpoints_evaled:
save_latest_filename = self.training_params.get(
'save_latest_filename', None)
'save_latest_filename',
None,
)

if not save_latest_filename:
rank = dist.get_global_rank()
Expand All @@ -463,11 +488,13 @@ def close(self, state: State, logger: Logger) -> None:
checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'

eval_run = self.launch_run(checkpoint, latest_timestamp)
self.checkpoints_evaled[latest_timestamp] = (checkpoint,
eval_run.name)
self.checkpoints_evaled[latest_timestamp] = (
checkpoint,
eval_run.name,
)

log.info(
f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:'
f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:',
)
for checkpoint_ts, (checkpoint,
run_name) in self.checkpoints_evaled.items():
Expand All @@ -477,13 +504,13 @@ def _get_current_run(self) -> Run:
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR,
'false').lower() == 'false':
raise RuntimeError(
'AsyncEval callback is only supported when running on the MosaicML platform'
'AsyncEval callback is only supported when running on the MosaicML platform',
)

run_name = os.environ.get(RUN_NAME_ENV_VAR, None)
if not run_name:
raise RuntimeError(
'RUN_NAME environment variable must be set to use the AsyncEval callback'
'RUN_NAME environment variable must be set to use the AsyncEval callback',
)

# Allows the MapiException to be raised if the run doesn't exist
Expand Down Expand Up @@ -542,7 +569,8 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
'No github integration found for llm-foundry. Adding installation '
+ f'to eval run for latest foundry release ({version}). ' +
'To use a fork, custom branch, or custom version, configure ' +
'llm-foundry installation through a github integration')
'llm-foundry installation through a github integration',
)
integrations.append({
'integration_type': 'git_repo',
'git_repo': 'mosaicml/llm-foundry',
Expand Down
Loading

0 comments on commit b1463a1

Please sign in to comment.