Add line splitting and other linting #1161

Merged
merged 1 commit on May 2, 2024
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -1,6 +1,12 @@
 default_language_version:
   python: python3
 repos:
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.2.2
+  hooks:
+  - id: ruff
+    args: [--fix, --exit-non-zero-on-fix]
 - repo: https://github.com/google/yapf
   rev: v0.32.0
   hooks:
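
For reviewers unfamiliar with the new hook: ruff lints the staged files, '--fix' applies safe autofixes, and '--exit-non-zero-on-fix' makes the hook fail even when everything was auto-fixed, so the commit stops and the fixed files can be re-staged. A rough sketch of the equivalent manual invocation (assuming ruff is installed; the real hook receives the staged file paths rather than '.'):

import subprocess

# Approximately what the new ruff hook runs: lint, apply safe fixes, and
# return non-zero if any fix was applied so the commit is blocked for review.
result = subprocess.run(
    ['ruff', 'check', '--fix', '--exit-non-zero-on-fix', '.'],
    capture_output=True,
    text=True,
)
print(result.stdout)
print('hook would pass' if result.returncode == 0 else 'hook would fail')
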
33 changes: 25 additions & 8 deletions llmfoundry/__init__.py
@@ -12,20 +12,37 @@
 from llmfoundry.utils.logging_utils import SpecificWarningFilter
 
 # Filter out Hugging Face warning for not using a pinned revision of the model
-hf_dynamic_modules_logger = logging.getLogger(
-    'transformers.dynamic_module_utils')
+logger = logging.getLogger('transformers.dynamic_module_utils')
 new_files_warning_filter = SpecificWarningFilter(
-    'A new version of the following files was downloaded from')
+    'A new version of the following files was downloaded from',
+)
 
-hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
+logger.addFilter(new_files_warning_filter)
 
-from llmfoundry import (algorithms, callbacks, cli, data, eval, interfaces,
-                        loggers, metrics, models, optim, tokenizers, utils)
+from llmfoundry import (
+    algorithms,
+    callbacks,
+    cli,
+    data,
+    eval,
+    interfaces,
+    loggers,
+    metrics,
+    models,
+    optim,
+    tokenizers,
+    utils,
+)
 from llmfoundry.data import StreamingFinetuningDataset, StreamingTextDataset
 from llmfoundry.eval import InContextLearningDataset, InContextLearningMetric
 from llmfoundry.models.hf import ComposerHFCausalLM
-from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
-                                   MPTForCausalLM, MPTModel, MPTPreTrainedModel)
+from llmfoundry.models.mpt import (
+    ComposerMPTCausalLM,
+    MPTConfig,
+    MPTForCausalLM,
+    MPTModel,
+    MPTPreTrainedModel,
+)
 from llmfoundry.optim import DecoupledLionW
 
 __all__ = [
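
Context for the logger change above: SpecificWarningFilter comes from llmfoundry.utils.logging_utils. A minimal sketch of the idea, assuming it simply drops log records whose message starts with the given prefix; the class below is a hypothetical stand-in, not the actual implementation:

import logging

class PrefixWarningFilter(logging.Filter):
    """Hypothetical stand-in for llmfoundry's SpecificWarningFilter."""

    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False suppresses the record; everything else passes.
        return not record.getMessage().startswith(self.prefix)

logger = logging.getLogger('transformers.dynamic_module_utils')
logger.addFilter(
    PrefixWarningFilter('A new version of the following files was downloaded from'),
)
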
8 changes: 6 additions & 2 deletions llmfoundry/algorithms/__init__.py
@@ -1,8 +1,12 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from composer.algorithms import (Alibi, GatedLinearUnits, GradientClipping,
-                                 LowPrecisionLayerNorm)
+from composer.algorithms import (
+    Alibi,
+    GatedLinearUnits,
+    GradientClipping,
+    LowPrecisionLayerNorm,
+)
 
 from llmfoundry.registry import algorithms
 
22 changes: 16 additions & 6 deletions llmfoundry/callbacks/__init__.py
@@ -1,10 +1,18 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from composer.callbacks import (EarlyStopper, EvalOutputLogging, Generate,
-                                LRMonitor, MemoryMonitor, MemorySnapshot,
-                                OOMObserver, OptimizerMonitor, RuntimeEstimator,
-                                SpeedMonitor)
+from composer.callbacks import (
+    EarlyStopper,
+    EvalOutputLogging,
+    Generate,
+    LRMonitor,
+    MemoryMonitor,
+    MemorySnapshot,
+    OOMObserver,
+    OptimizerMonitor,
+    RuntimeEstimator,
+    SpeedMonitor,
+)
 
 from llmfoundry.callbacks.async_eval_callback import AsyncEval
 from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning
@@ -15,8 +23,10 @@
     MegaBlocksMoE_TokPerExpert
 from llmfoundry.callbacks.monolithic_ckpt_callback import \
     MonolithicCheckpointSaver
-from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
-                                                       LayerFreezing)
+from llmfoundry.callbacks.resumption_callbacks import (
+    GlobalLRScaling,
+    LayerFreezing,
+)
 from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
 from llmfoundry.registry import callbacks, callbacks_with_config
 
106 changes: 67 additions & 39 deletions llmfoundry/callbacks/async_eval_callback.py
@@ -16,8 +16,10 @@
 from composer.callbacks import CheckpointSaver
 from composer.core import Event, State, Time, Timestamp, TimeUnit
 from composer.loggers import Logger
-from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR,
-                                              RUN_NAME_ENV_VAR)
+from composer.loggers.mosaicml_logger import (
+    MOSAICML_PLATFORM_ENV_VAR,
+    RUN_NAME_ENV_VAR,
+)
 from composer.utils import dist
 from composer.utils.file_helpers import list_remote_objects
 from composer.utils.misc import create_interval_scheduler
@@ -65,15 +67,17 @@ def get_run_name(training_run_name: str, current_interval: str) -> str:
"""
name_without_uuid_suffix = training_run_name.rsplit('-', 1)[0]

max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len(
current_interval) - 2
max_length = MAX_RUN_NAME_BASE_LENGTH - len(
RUN_NAME_PREFIX,
) - len(current_interval) - 2

# A run name that is too long will fail a createRun call
if len(name_without_uuid_suffix) > max_length:
new_name = name_without_uuid_suffix[:max_length]
log.warning(
f'Training run name {name_without_uuid_suffix} may be too long,' +
f' truncating to {new_name}')
f' truncating to {new_name}',
)
name_without_uuid_suffix = new_name

return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'
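
To see the reflowed arithmetic in action, a self-contained sketch of get_run_name; RUN_NAME_PREFIX and MAX_RUN_NAME_BASE_LENGTH are illustrative values here, not the module's real constants:

RUN_NAME_PREFIX = 'eval'  # assumed value
MAX_RUN_NAME_BASE_LENGTH = 55  # assumed value

def sketch_get_run_name(training_run_name: str, current_interval: str) -> str:
    name = training_run_name.rsplit('-', 1)[0]  # strip the uuid suffix
    # The -2 accounts for the two joining hyphens in the final f-string.
    max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len(current_interval) - 2
    return f'{RUN_NAME_PREFIX}-{current_interval}-{name[:max_length]}'

print(sketch_get_run_name('my-training-run-abc123', '1ep'))
# -> 'eval-1ep-my-training-run'
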
@@ -107,7 +111,7 @@ def get_eval_parameters(

     if looking_for:
         raise Exception(
-            f'Missing the following required parameters for async eval: {looking_for}'
+            f'Missing the following required parameters for async eval: {looking_for}',
         )
 
     for logger, config in subset_keys.get('loggers', {}).items():
@@ -126,7 +130,7 @@
     new_models = {
         'model_name': model_name,
         'model': model,
-        'load_path': checkpoint
+        'load_path': checkpoint,
     }
 
     tokenizer = subset_keys.pop('tokenizer', None)
@@ -136,27 +140,32 @@
     return subset_keys
 
 
-def validate_interval(interval: Union[str, int, Time],
-                      save_interval: Union[str, int, Time]) -> Time:
+def validate_interval(
+    interval: Union[str, int, Time],
+    save_interval: Union[str, int, Time],
+) -> Time:
 
     new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH)
     async_interval = Time.from_input(interval, TimeUnit.EPOCH)
 
     if new_save_interval.unit != async_interval.unit:
         raise ValueError(
-            'Save interval and async eval interval must be in the same unit')
+            'Save interval and async eval interval must be in the same unit',
+        )
     if async_interval < new_save_interval:
         raise ValueError(
-            'Async eval interval must be equal or greater (less frequent) than save interval'
+            'Async eval interval must be equal or greater (less frequent) than save interval',
         )
     if async_interval.value % new_save_interval.value != 0:
         raise ValueError(
-            'Async eval interval must be a multiple of save interval')
+            'Async eval interval must be a multiple of save interval',
+        )
     return async_interval
 
 
 def validate_eval_run_config(
-        eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+    eval_run_config: Optional[Dict[str, Any]],
+) -> Dict[str, Any]:
 
     if not eval_run_config:
         return {}
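
A quick usage sketch for validate_interval, assuming composer's time strings ('ba' is batches) and that the function is imported from this module:

# Save every 500 batches, eval every 1000 batches: same unit and a clean
# multiple, so this returns the eval interval as a Time object.
print(validate_interval('1000ba', '500ba'))

# Evaling more often than saving is rejected: there would be no checkpoint
# for most eval runs to load.
try:
    validate_interval('250ba', '500ba')
except ValueError as err:
    print(err)
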
@@ -172,7 +181,8 @@ def validate_eval_run_config(
     if found_unsupported:
         raise ValueError(
             f'Unsupported eval run config keys found: {", ".join(found_unsupported)}'
-            + f'. Supported keys: {supported_keys}')
+            + f'. Supported keys: {supported_keys}',
+        )
 
     return run_config
 
@@ -222,7 +232,7 @@ def __init__(

         if '/' in train_config.get('save_filename', ''):
             raise ValueError(
-                'AsyncEval not supported for save_filename that includes a path'
+                'AsyncEval not supported for save_filename that includes a path',
             )
 
         self.checkpoint_save_folder = train_config['save_folder']
@@ -237,14 +247,18 @@ def __init__(
         )
 
         # Validate the interval (how often to launch eval runs)
-        self.interval = validate_interval(interval,
-                                          self.training_params['save_interval'])
+        self.interval = validate_interval(
+            interval,
+            self.training_params['save_interval'],
+        )
 
         # Configures how often to check for new checkpoints. This is semi-arbitrary;
         # really we just want to check often enough to pull relevant checkpoints
         # but not so often that we're constantly checking
-        check_interval_value = max(self.interval.value // CHECKS_PER_INTERVAL,
-                                   1)
+        check_interval_value = max(
+            self.interval.value // CHECKS_PER_INTERVAL,
+            1,
+        )
         self.check_interval = Time(check_interval_value, self.interval.unit)
 
         # Keep track of checkpoints that have already been evaled
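
To make the check cadence concrete, a small sketch of the computation above; CHECKS_PER_INTERVAL's real value is defined at module level, 4 is assumed here for illustration:

CHECKS_PER_INTERVAL = 4  # assumed value

def sketch_check_interval(interval_value: int) -> int:
    # Poll for new checkpoints several times per eval interval,
    # but never less often than once per time unit.
    return max(interval_value // CHECKS_PER_INTERVAL, 1)

print(sketch_check_interval(1000))  # 250 -> check every 250 units
print(sketch_check_interval(2))  # 1 -> floored at one unit
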
@@ -260,8 +274,10 @@ def __init__(
             include_end_of_training=False,
         )
 
-        log.info('Initialized AsyncEval callback. Will generate runs at ' +
-                 f'interval {interval}, checking at {self.check_interval}')
+        log.info(
+            'Initialized AsyncEval callback. Will generate runs at ' +
+            f'interval {interval}, checking at {self.check_interval}',
+        )
 
     def state_dict(self) -> Dict[str, Any]:
         checkpoints_evaled = []
@@ -284,7 +300,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
             self.checkpoints_evaled[eval_ts] = (checkpoint, run_name)
 
         log.info(
-            f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}'
+            f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}',
         )
 
     @staticmethod
@@ -318,12 +334,12 @@ def _get_ready_sharded_checkpoints(

             # expecting one shard per gpu + 1 for metadata
             expected_shard_count = dist.get_world_size() + 1
-            if remote_file_group_counts[
-                    checkpoint_ts_path] != expected_shard_count:
+            if remote_file_group_counts[checkpoint_ts_path
+                                       ] != expected_shard_count:
                 log.debug(
                     f'Checkpoint {checkpoint} not fully uploaded (missing shards '
                     +
-                    f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping'
+                    f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping',
                 )
                 continue

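
Why world_size + 1: with sharded (FSDP) checkpoints, each rank uploads one shard and the save also produces a shared metadata file, so a checkpoint only counts as ready once all of those objects are visible remotely. A sketch of the readiness count under that assumption; the file names below are hypothetical:

from collections import Counter

world_size = 4
expected_shard_count = world_size + 1  # one shard per rank + metadata

# Hypothetical remote listing for one sharded checkpoint folder.
remote_files = [
    'ep0-ba500/__0_0.distcp',
    'ep0-ba500/__1_0.distcp',
    'ep0-ba500/__2_0.distcp',
    'ep0-ba500/__3_0.distcp',
    'ep0-ba500/.metadata',
]

counts = Counter(path.split('/')[0] for path in remote_files)
print(counts['ep0-ba500'] == expected_shard_count)  # True -> ready to eval
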
@@ -357,7 +373,8 @@ def _get_ready_single_checkpoints(

             if checkpoint not in unique_remote_checkpoints:
                 log.debug(
-                    f'Checkpoint {checkpoint} not fully uploaded, skipping')
+                    f'Checkpoint {checkpoint} not fully uploaded, skipping',
+                )
                 continue
 
             checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
@@ -379,20 +396,23 @@ def _get_checkpoints_and_launch_runs(self, state: State):
                     checkpointer = callback
                 else:
                     log.warning(
-                        'Multiple checkpoint savers found. Using the first one')
+                        'Multiple checkpoint savers found. Using the first one',
+                    )
 
         if not checkpointer:
            warnings.warn('No checkpoint saver callback found. Skipping eval')
            return
 
        if not checkpointer.all_saved_checkpoints_to_timestamp:
            log.debug(
-                'No saved checkpoints found on the checkpointer. Skipping eval')
+                'No saved checkpoints found on the checkpointer. Skipping eval',
+            )
            return
 
        log.debug(
            f'Found {len(checkpointer.all_saved_checkpoints_to_timestamp)} ' +
-            f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}')
+            f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}',
+        )
 
        remote_checkpoints = list_remote_objects(self.checkpoint_save_folder)
 
@@ -403,19 +423,22 @@ def _get_checkpoints_and_launch_runs(self, state: State):
         if state.fsdp_sharded_state_dict_enabled:
             checkpoints_to_eval = self._get_ready_sharded_checkpoints(
                 checkpointer.all_saved_checkpoints_to_timestamp,
-                remote_checkpoints)
+                remote_checkpoints,
+            )
         else:
             checkpoints_to_eval = self._get_ready_single_checkpoints(
                 checkpointer.all_saved_checkpoints_to_timestamp,
-                remote_checkpoints)
+                remote_checkpoints,
+            )
 
         for checkpoint_interval_path, checkpoint_timestamp in checkpoints_to_eval.items(
         ):
             checkpoint_ts = checkpoint_timestamp.get(self.interval.unit)
             if checkpoint_ts.value % self.interval.value != 0:
                 log.debug(
                     f'Checkpoint {checkpoint_interval_path} ({checkpoint_ts}) is '
-                    + f'not at an eval interval ({self.interval}), skipping')
+                    + f'not at an eval interval ({self.interval}), skipping',
+                )
                 continue
             if checkpoint_ts in self.checkpoints_evaled:
                 continue  # Skip checkpoints that have already been evaled
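
The modulo filter above is what allows the eval interval to be a coarser multiple of the save interval; a small sketch with checkpoints saved every 500 batches and an eval interval of 1000 batches:

interval_value = 1000  # eval interval, in batches

# Only checkpoints that land exactly on an eval boundary are launched.
for checkpoint_ba in (500, 1000, 1500, 2000):
    eligible = checkpoint_ba % interval_value == 0
    print(checkpoint_ba, 'eval' if eligible else 'skip')
# 500 skip, 1000 eval, 1500 skip, 2000 eval
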
@@ -454,7 +477,9 @@ def close(self, state: State, logger: Logger) -> None:
         latest_timestamp = state.timestamp.get(self.interval.unit)
         if latest_timestamp not in self.checkpoints_evaled:
             save_latest_filename = self.training_params.get(
-                'save_latest_filename', None)
+                'save_latest_filename',
+                None,
+            )
 
             if not save_latest_filename:
                 rank = dist.get_global_rank()
@@ -463,11 +488,13 @@ def close(self, state: State, logger: Logger) -> None:
             checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
 
             eval_run = self.launch_run(checkpoint, latest_timestamp)
-            self.checkpoints_evaled[latest_timestamp] = (checkpoint,
-                                                         eval_run.name)
+            self.checkpoints_evaled[latest_timestamp] = (
+                checkpoint,
+                eval_run.name,
+            )
 
         log.info(
-            f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:'
+            f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:',
         )
         for checkpoint_ts, (checkpoint,
                             run_name) in self.checkpoints_evaled.items():
@@ -477,13 +504,13 @@ def _get_current_run(self) -> Run:
         if os.environ.get(MOSAICML_PLATFORM_ENV_VAR,
                           'false').lower() == 'false':
             raise RuntimeError(
-                'AsyncEval callback is only supported when running on the MosaicML platform'
+                'AsyncEval callback is only supported when running on the MosaicML platform',
             )
 
         run_name = os.environ.get(RUN_NAME_ENV_VAR, None)
         if not run_name:
             raise RuntimeError(
-                'RUN_NAME environment variable must be set to use the AsyncEval callback'
+                'RUN_NAME environment variable must be set to use the AsyncEval callback',
             )
 
         # Allows the MapiException to be raised if the run doesn't exist
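
Both guards read environment variables that the MosaicML platform sets on managed runs. A sketch of satisfying them locally for testing; 'MOSAICML_PLATFORM' and 'RUN_NAME' are assumed to be the concrete names behind the two imported constants:

import os

# Assumed names behind MOSAICML_PLATFORM_ENV_VAR / RUN_NAME_ENV_VAR.
os.environ['MOSAICML_PLATFORM'] = 'true'
os.environ['RUN_NAME'] = 'my-training-run-abc123'
# With both set, _get_current_run proceeds to look up the run via the
# MosaicML API client instead of raising.
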
@@ -542,7 +569,8 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
                 'No github integration found for llm-foundry. Adding installation '
                 + f'to eval run for latest foundry release ({version}). ' +
                 'To use a fork, custom branch, or custom version, configure ' +
-                'llm-foundry installation through a github integration')
+                'llm-foundry installation through a github integration',
+            )
             integrations.append({
                 'integration_type': 'git_repo',
                 'git_repo': 'mosaicml/llm-foundry',