diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3551482244..dc2e3f55cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,12 @@ default_language_version: python: python3 repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.2.2 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/google/yapf rev: v0.32.0 hooks: diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 012147ec20..c32b9736df 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -12,20 +12,37 @@ from llmfoundry.utils.logging_utils import SpecificWarningFilter # Filter out Hugging Face warning for not using a pinned revision of the model -hf_dynamic_modules_logger = logging.getLogger( - 'transformers.dynamic_module_utils') +logger = logging.getLogger('transformers.dynamic_module_utils') new_files_warning_filter = SpecificWarningFilter( - 'A new version of the following files was downloaded from') + 'A new version of the following files was downloaded from', +) -hf_dynamic_modules_logger.addFilter(new_files_warning_filter) +logger.addFilter(new_files_warning_filter) -from llmfoundry import (algorithms, callbacks, cli, data, eval, interfaces, - loggers, metrics, models, optim, tokenizers, utils) +from llmfoundry import ( + algorithms, + callbacks, + cli, + data, + eval, + interfaces, + loggers, + metrics, + models, + optim, + tokenizers, + utils, +) from llmfoundry.data import StreamingFinetuningDataset, StreamingTextDataset from llmfoundry.eval import InContextLearningDataset, InContextLearningMetric from llmfoundry.models.hf import ComposerHFCausalLM -from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, - MPTForCausalLM, MPTModel, MPTPreTrainedModel) +from llmfoundry.models.mpt import ( + ComposerMPTCausalLM, + MPTConfig, + MPTForCausalLM, + MPTModel, + MPTPreTrainedModel, +) from llmfoundry.optim import DecoupledLionW __all__ = [ diff --git a/llmfoundry/algorithms/__init__.py b/llmfoundry/algorithms/__init__.py index 78e3d270fd..6459c2f0fb 100644 --- a/llmfoundry/algorithms/__init__.py +++ b/llmfoundry/algorithms/__init__.py @@ -1,8 +1,12 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from composer.algorithms import (Alibi, GatedLinearUnits, GradientClipping, - LowPrecisionLayerNorm) +from composer.algorithms import ( + Alibi, + GatedLinearUnits, + GradientClipping, + LowPrecisionLayerNorm, +) from llmfoundry.registry import algorithms diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index dc3ee707ac..d1c79dd82c 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -1,10 +1,18 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from composer.callbacks import (EarlyStopper, EvalOutputLogging, Generate, - LRMonitor, MemoryMonitor, MemorySnapshot, - OOMObserver, OptimizerMonitor, RuntimeEstimator, - SpeedMonitor) +from composer.callbacks import ( + EarlyStopper, + EvalOutputLogging, + Generate, + LRMonitor, + MemoryMonitor, + MemorySnapshot, + OOMObserver, + OptimizerMonitor, + RuntimeEstimator, + SpeedMonitor, +) from llmfoundry.callbacks.async_eval_callback import AsyncEval from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning @@ -15,8 +23,10 @@ MegaBlocksMoE_TokPerExpert from llmfoundry.callbacks.monolithic_ckpt_callback import \ MonolithicCheckpointSaver -from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling, - LayerFreezing) +from llmfoundry.callbacks.resumption_callbacks import ( + GlobalLRScaling, + LayerFreezing, +) from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector from llmfoundry.registry import callbacks, callbacks_with_config diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index dc0a1b10f0..646d86c8d3 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -16,8 +16,10 @@ from composer.callbacks import CheckpointSaver from composer.core import Event, State, Time, Timestamp, TimeUnit from composer.loggers import Logger -from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR, - RUN_NAME_ENV_VAR) +from composer.loggers.mosaicml_logger import ( + MOSAICML_PLATFORM_ENV_VAR, + RUN_NAME_ENV_VAR, +) from composer.utils import dist from composer.utils.file_helpers import list_remote_objects from composer.utils.misc import create_interval_scheduler @@ -65,15 +67,17 @@ def get_run_name(training_run_name: str, current_interval: str) -> str: """ name_without_uuid_suffix = training_run_name.rsplit('-', 1)[0] - max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len( - current_interval) - 2 + max_length = MAX_RUN_NAME_BASE_LENGTH - len( + RUN_NAME_PREFIX, + ) - len(current_interval) - 2 # A run name that is too long will fail a createRun call if len(name_without_uuid_suffix) > max_length: new_name = name_without_uuid_suffix[:max_length] log.warning( f'Training run name {name_without_uuid_suffix} may be too long,' + - f' truncating to {new_name}') + f' truncating to {new_name}', + ) name_without_uuid_suffix = new_name return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}' @@ -107,7 +111,7 @@ def get_eval_parameters( if looking_for: raise Exception( - f'Missing the following required parameters for async eval: {looking_for}' + f'Missing the following required parameters for async eval: {looking_for}', ) for logger, config in subset_keys.get('loggers', {}).items(): @@ -126,7 +130,7 @@ def get_eval_parameters( new_models = { 'model_name': model_name, 'model': model, - 'load_path': checkpoint + 'load_path': checkpoint, } tokenizer = subset_keys.pop('tokenizer', None) @@ -136,27 +140,32 @@ def get_eval_parameters( return subset_keys -def validate_interval(interval: Union[str, int, Time], - save_interval: Union[str, int, Time]) -> Time: +def validate_interval( + interval: Union[str, int, Time], + save_interval: Union[str, int, Time], +) -> Time: new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH) async_interval = Time.from_input(interval, TimeUnit.EPOCH) if new_save_interval.unit != async_interval.unit: raise ValueError( - 'Save interval and async eval interval must be in the same unit') + 'Save interval and async eval interval must be in the same unit', + ) if async_interval < new_save_interval: raise ValueError( - 'Async eval interval must be equal or greater (less frequent) than save interval' + 'Async eval interval must be equal or greater (less frequent) than save interval', ) if async_interval.value % new_save_interval.value != 0: raise ValueError( - 'Async eval interval must be a multiple of save interval') + 'Async eval interval must be a multiple of save interval', + ) return async_interval def validate_eval_run_config( - eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]: + eval_run_config: Optional[Dict[str, Any]], +) -> Dict[str, Any]: if not eval_run_config: return {} @@ -172,7 +181,8 @@ def validate_eval_run_config( if found_unsupported: raise ValueError( f'Unsupported eval run config keys found: {", ".join(found_unsupported)}' - + f'. Supported keys: {supported_keys}') + + f'. Supported keys: {supported_keys}', + ) return run_config @@ -222,7 +232,7 @@ def __init__( if '/' in train_config.get('save_filename', ''): raise ValueError( - 'AsyncEval not supported for save_filename that includes a path' + 'AsyncEval not supported for save_filename that includes a path', ) self.checkpoint_save_folder = train_config['save_folder'] @@ -237,14 +247,18 @@ def __init__( ) # Validate the interval (how often to launch eval runs) - self.interval = validate_interval(interval, - self.training_params['save_interval']) + self.interval = validate_interval( + interval, + self.training_params['save_interval'], + ) # Configures how often to check for new checkpoints. This is semi-arbitrary; # really we just want to check often enough to pull relevant checkpoints # but not so often that we're constantly checking - check_interval_value = max(self.interval.value // CHECKS_PER_INTERVAL, - 1) + check_interval_value = max( + self.interval.value // CHECKS_PER_INTERVAL, + 1, + ) self.check_interval = Time(check_interval_value, self.interval.unit) # Keep track of checkpoints that have already been evaled @@ -260,8 +274,10 @@ def __init__( include_end_of_training=False, ) - log.info('Initialized AsyncEval callback. Will generate runs at ' + - f'interval {interval}, checking at {self.check_interval}') + log.info( + 'Initialized AsyncEval callback. Will generate runs at ' + + f'interval {interval}, checking at {self.check_interval}', + ) def state_dict(self) -> Dict[str, Any]: checkpoints_evaled = [] @@ -284,7 +300,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self.checkpoints_evaled[eval_ts] = (checkpoint, run_name) log.info( - f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}' + f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}', ) @staticmethod @@ -318,12 +334,12 @@ def _get_ready_sharded_checkpoints( # expecting one shard per gpu + 1 for metadata expected_shard_count = dist.get_world_size() + 1 - if remote_file_group_counts[ - checkpoint_ts_path] != expected_shard_count: + if remote_file_group_counts[checkpoint_ts_path + ] != expected_shard_count: log.debug( f'Checkpoint {checkpoint} not fully uploaded (missing shards ' + - f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping' + f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping', ) continue @@ -357,7 +373,8 @@ def _get_ready_single_checkpoints( if checkpoint not in unique_remote_checkpoints: log.debug( - f'Checkpoint {checkpoint} not fully uploaded, skipping') + f'Checkpoint {checkpoint} not fully uploaded, skipping', + ) continue checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts @@ -379,7 +396,8 @@ def _get_checkpoints_and_launch_runs(self, state: State): checkpointer = callback else: log.warning( - 'Multiple checkpoint savers found. Using the first one') + 'Multiple checkpoint savers found. Using the first one', + ) if not checkpointer: warnings.warn('No checkpoint saver callback found. Skipping eval') @@ -387,12 +405,14 @@ def _get_checkpoints_and_launch_runs(self, state: State): if not checkpointer.all_saved_checkpoints_to_timestamp: log.debug( - 'No saved checkpoints found on the checkpointer. Skipping eval') + 'No saved checkpoints found on the checkpointer. Skipping eval', + ) return log.debug( f'Found {len(checkpointer.all_saved_checkpoints_to_timestamp)} ' + - f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}') + f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}', + ) remote_checkpoints = list_remote_objects(self.checkpoint_save_folder) @@ -403,11 +423,13 @@ def _get_checkpoints_and_launch_runs(self, state: State): if state.fsdp_sharded_state_dict_enabled: checkpoints_to_eval = self._get_ready_sharded_checkpoints( checkpointer.all_saved_checkpoints_to_timestamp, - remote_checkpoints) + remote_checkpoints, + ) else: checkpoints_to_eval = self._get_ready_single_checkpoints( checkpointer.all_saved_checkpoints_to_timestamp, - remote_checkpoints) + remote_checkpoints, + ) for checkpoint_interval_path, checkpoint_timestamp in checkpoints_to_eval.items( ): @@ -415,7 +437,8 @@ def _get_checkpoints_and_launch_runs(self, state: State): if checkpoint_ts.value % self.interval.value != 0: log.debug( f'Checkpoint {checkpoint_interval_path} ({checkpoint_ts}) is ' - + f'not at an eval interval ({self.interval}), skipping') + + f'not at an eval interval ({self.interval}), skipping', + ) continue if checkpoint_ts in self.checkpoints_evaled: continue # Skip checkpoints that have already been evaled @@ -454,7 +477,9 @@ def close(self, state: State, logger: Logger) -> None: latest_timestamp = state.timestamp.get(self.interval.unit) if latest_timestamp not in self.checkpoints_evaled: save_latest_filename = self.training_params.get( - 'save_latest_filename', None) + 'save_latest_filename', + None, + ) if not save_latest_filename: rank = dist.get_global_rank() @@ -463,11 +488,13 @@ def close(self, state: State, logger: Logger) -> None: checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}' eval_run = self.launch_run(checkpoint, latest_timestamp) - self.checkpoints_evaled[latest_timestamp] = (checkpoint, - eval_run.name) + self.checkpoints_evaled[latest_timestamp] = ( + checkpoint, + eval_run.name, + ) log.info( - f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:' + f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:', ) for checkpoint_ts, (checkpoint, run_name) in self.checkpoints_evaled.items(): @@ -477,13 +504,13 @@ def _get_current_run(self) -> Run: if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'false': raise RuntimeError( - 'AsyncEval callback is only supported when running on the MosaicML platform' + 'AsyncEval callback is only supported when running on the MosaicML platform', ) run_name = os.environ.get(RUN_NAME_ENV_VAR, None) if not run_name: raise RuntimeError( - 'RUN_NAME environment variable must be set to use the AsyncEval callback' + 'RUN_NAME environment variable must be set to use the AsyncEval callback', ) # Allows the MapiException to be raised if the run doesn't exist @@ -542,7 +569,8 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: 'No github integration found for llm-foundry. Adding installation ' + f'to eval run for latest foundry release ({version}). ' + 'To use a fork, custom branch, or custom version, configure ' + - 'llm-foundry installation through a github integration') + 'llm-foundry installation through a github integration', + ) integrations.append({ 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 3400ee8bb1..1fb059070f 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -51,13 +51,15 @@ def before_load(self, state: State, logger: Logger): if not isinstance(train_loader, DataLoader): raise ValueError( f'CurriculumLearning callback can only be used with a train ', - f'dataloader of type DataLoader, but got {type(train_loader)}.') + f'dataloader of type DataLoader, but got {type(train_loader)}.', + ) dataset = train_loader.dataset if not isinstance(dataset, StreamingDataset): raise ValueError( f'CurriculumLearning callback only supports StreamingDataset ', f'because it requires loading and saving dataset state. ', - f'Instead, got a dataset of type {type(dataset)}') + f'Instead, got a dataset of type {type(dataset)}', + ) assert isinstance(dataset, StreamingDataset) # Save the current dataset state so we can restore it if needed. self.current_dataset_state = dataset.state_dict( # type: ignore @@ -72,10 +74,12 @@ def after_load(self, state: State, logger: Logger): train_loader = state._train_dataloader assert isinstance( train_loader, - DataLoader), 'CurriculumLearning callback requires a DataLoader.' + DataLoader, + ), 'CurriculumLearning callback requires a DataLoader.' dataset = train_loader.dataset assert isinstance( - dataset, StreamingDataset + dataset, + StreamingDataset, ), 'CurriculumLearning callback requires a StreamingDataset.' if self.saved_dataset_index < self.dataset_index: # Ignore the dataset state that was read in from the checkpoint, and @@ -101,7 +105,7 @@ def after_load(self, state: State, logger: Logger): def state_dict(self): return { 'dataset_index': self.dataset_index, - 'all_dataset_configs': self.all_dataset_configs + 'all_dataset_configs': self.all_dataset_configs, } def load_state_dict(self, state: Dict[str, Any]): diff --git a/llmfoundry/callbacks/eval_gauntlet_callback.py b/llmfoundry/callbacks/eval_gauntlet_callback.py index 7544d66040..4d0f685ecd 100644 --- a/llmfoundry/callbacks/eval_gauntlet_callback.py +++ b/llmfoundry/callbacks/eval_gauntlet_callback.py @@ -22,8 +22,10 @@ class Weighting(Enum): LOG_SAMPLE_SZ = 3 -def calculate_named_averages(average_names: Dict[str, list], - category_scores: Dict[str, float]): +def calculate_named_averages( + average_names: Dict[str, list], + category_scores: Dict[str, float], +): """Calculates the named averages based off the raw category scores. For each named average, take a simple average of all the category scores associated with that named average. @@ -40,8 +42,9 @@ def calculate_named_averages(average_names: Dict[str, list], if category in category_list } if len(composite_subset.values()) > 0: - average_scores[avg_name] = sum(composite_subset.values()) / len( - composite_subset.values()) + average_scores[avg_name] = sum( + composite_subset.values(), + ) / len(composite_subset.values()) else: average_scores[avg_name] = 0 @@ -72,25 +75,28 @@ class EvalGauntlet(Callback): averages (Optional[dict]): Optional dictionary specifying a mapping from a average names to lists of categories used produce each named average. """ - def __init__(self, - logger_keys: list, - categories: dict, - weighting: str = 'EQUAL', - subtract_random_baseline: bool = True, - rescale_accuracy: bool = True, - benchmark_sizes: Optional[dict] = None, - averages: Optional[dict] = None): + def __init__( + self, + logger_keys: list, + categories: dict, + weighting: str = 'EQUAL', + subtract_random_baseline: bool = True, + rescale_accuracy: bool = True, + benchmark_sizes: Optional[dict] = None, + averages: Optional[dict] = None, + ): if isinstance(logger_keys, dict): raise ValueError( - 'logger_keys now requires a list type as input, not a dict') + 'logger_keys now requires a list type as input, not a dict', + ) if weighting != Weighting.EQUAL and benchmark_sizes is None: raise Exception( - 'When not using equal weighting, you must provide the benchmark sizes.' + 'When not using equal weighting, you must provide the benchmark sizes.', ) if rescale_accuracy and not subtract_random_baseline: raise Exception( - 'Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.' + 'Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.', ) self.categories = categories @@ -106,8 +112,12 @@ def __init__(self, if self.weighting != Weighting.EQUAL: assert benchmark_sizes is not None cumulative_samples = max( - sum(count for name, count in benchmark_sizes.items() - if name.startswith(bench_name)), 1) + sum( + count for name, count in benchmark_sizes.items() + if name.startswith(bench_name) + ), + 1, + ) else: cumulative_samples = -1 # pyright @@ -131,7 +141,7 @@ def __init__(self, for avg_name in self.averages: if avg_name in self.category_names: raise ValueError( - f'Found average name `{avg_name}` used as category name. Average names and category names must be non-overlapping.' + f'Found average name `{avg_name}` used as category name. Average names and category names must be non-overlapping.', ) def extract_metrics_from_state(self, state: State) -> Dict[str, float]: @@ -172,7 +182,8 @@ def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]: if key not in computed_metrics: log.warning( - f'Could not find results for benchmark: {benchmark}.') + f'Could not find results for benchmark: {benchmark}.', + ) missing_metrics.append(key) else: score = computed_metrics[key] @@ -186,23 +197,27 @@ def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]: category_scores[category['name']].append({ 'name': benchmark['name'], 'score': score, - 'weighting': benchmark['weighting'] + 'weighting': benchmark['weighting'], }) if len(missing_metrics) > 0: log.warning( - f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}" + f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}", ) del category_scores[category['name']] continue total_weight = sum( - k['weighting'] for k in category_scores[category['name']]) + k['weighting'] for k in category_scores[category['name']] + ) category_scores[category['name']] = sum( k['score'] * (k['weighting'] / total_weight) - for k in category_scores[category['name']]) + for k in category_scores[category['name']] + ) - named_averages = calculate_named_averages(self.averages, - category_scores) + named_averages = calculate_named_averages( + self.averages, + category_scores, + ) category_scores.update(named_averages) category_scores = { f'icl/metrics/eval_gauntlet/{k}': v diff --git a/llmfoundry/callbacks/fdiff_callback.py b/llmfoundry/callbacks/fdiff_callback.py index 2afcc94452..4ab3e9f82d 100644 --- a/llmfoundry/callbacks/fdiff_callback.py +++ b/llmfoundry/callbacks/fdiff_callback.py @@ -18,9 +18,11 @@ class FDiffMetrics(Callback): numerical derivative of the metrics """ - def __init__(self, - diff_train_metrics: bool = False, - diff_eval_metrics: bool = True): + def __init__( + self, + diff_train_metrics: bool = False, + diff_eval_metrics: bool = True, + ): self.diff_train_metrics = diff_train_metrics self.diff_eval_metrics = diff_eval_metrics @@ -34,14 +36,16 @@ def batch_end(self, state: State, logger: Logger) -> None: raise NotImplementedError('Multiple losses not supported yet') loss = state.loss.item() if self.train_prev_loss: - logger.log_metrics( - {'loss/train/total_fdiff': loss - self.train_prev_loss}) + logger.log_metrics({ + 'loss/train/total_fdiff': loss - self.train_prev_loss, + }) self.train_prev_loss = loss for k in self.train_prev_metric.keys(): logger.log_metrics({ f'metrics/train/{k}_fdiff': - state.train_metric_values[k] - self.train_prev_metric[k] + state.train_metric_values[k] - + self.train_prev_metric[k], }) for k in state.train_metric_values.keys(): @@ -61,7 +65,7 @@ def eval_end(self, state: State, logger: Logger) -> None: logger.log_metrics({ f'{mkey}_fdiff': state.eval_metric_values[k] - - self.eval_prev_metric[mkey] + self.eval_prev_metric[mkey], }) for k in metrics: diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index feb9ff98b5..28b33b43d8 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -21,9 +21,12 @@ from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger from composer.models import HuggingFaceModel -from composer.utils import (dist, format_name_with_dist_and_time, - maybe_create_remote_uploader_downloader_from_uri, - parse_uri) +from composer.utils import ( + dist, + format_name_with_dist_and_time, + maybe_create_remote_uploader_downloader_from_uri, + parse_uri, +) from composer.utils.misc import create_interval_scheduler from mlflow.transformers import _fetch_model_card, _write_license_information from packaging import version @@ -42,8 +45,9 @@ def _maybe_get_license_filename( - local_dir: str, - pretrained_model_name: Optional[str] = None) -> Optional[str]: + local_dir: str, + pretrained_model_name: Optional[str] = None, +) -> Optional[str]: """Returns the name of the license file if it exists in the local_dir. Note: This is intended to be consistent with the code in MLflow. @@ -57,22 +61,29 @@ def _maybe_get_license_filename( If the license file does not exist, returns None. """ try: - license_filename = next(file for file in os.listdir(local_dir) - if _LICENSE_FILE_PATTERN.search(file)) + license_filename = next( + file for file in os.listdir(local_dir) + if _LICENSE_FILE_PATTERN.search(file) + ) # If a pretrained model name is provided, replace the license file with the correct info from HF Hub. if pretrained_model_name is not None: log.info( - f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub' + f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub', ) os.remove(os.path.join(local_dir, license_filename)) model_card = _fetch_model_card(pretrained_model_name) local_dir_path = Path(local_dir).absolute() - _write_license_information(pretrained_model_name, model_card, - local_dir_path) - license_filename = next(file for file in os.listdir(local_dir) - if _LICENSE_FILE_PATTERN.search(file)) + _write_license_information( + pretrained_model_name, + model_card, + local_dir_path, + ) + license_filename = next( + file for file in os.listdir(local_dir) + if _LICENSE_FILE_PATTERN.search(file) + ) return license_filename @@ -96,13 +107,16 @@ def _register_model_with_run_id_multiprocess( # If logging_level is 0, then the composer logger was unset. logging.basicConfig( format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', ) logging.getLogger('composer').setLevel(composer_logging_level) # Register model. mlflow_logger.register_model_with_run_id( - model_uri=model_uri, name=name, await_creation_for=await_creation_for) + model_uri=model_uri, + name=name, + await_creation_for=await_creation_for, + ) class HuggingFaceCheckpointer(Callback): @@ -136,15 +150,15 @@ class HuggingFaceCheckpointer(Callback): """ def __init__( - self, - save_folder: str, - save_interval: Union[str, int, Time], - huggingface_folder_name: str = 'ba{batch}', - precision: str = 'float32', - overwrite: bool = True, - mlflow_registered_model_name: Optional[str] = None, - mlflow_logging_config: Optional[dict] = None, - flatten_imports: Sequence[str] = ('llmfoundry',), + self, + save_folder: str, + save_interval: Union[str, int, Time], + huggingface_folder_name: str = 'ba{batch}', + precision: str = 'float32', + overwrite: bool = True, + mlflow_registered_model_name: Optional[str] = None, + mlflow_logging_config: Optional[dict] = None, + flatten_imports: Sequence[str] = ('llmfoundry',), ): _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite @@ -168,33 +182,44 @@ def __init__( mlflow_logging_config.setdefault('task', 'llm/v1/completions') default_input_example = { - 'prompt': np.array(['What is Machine Learning?']) + 'prompt': np.array(['What is Machine Learning?']), } is_chat = mlflow_logging_config['task'].endswith('chat') or ( mlflow_logging_config['metadata'] is not None and mlflow_logging_config['metadata'].get('task', - '').endswith('chat')) + '').endswith('chat') + ) if is_chat: default_input_example = { 'messages': [{ 'role': 'user', - 'content': 'What is Machine Learning?' - }] + 'content': 'What is Machine Learning?', + }], } - mlflow_logging_config.setdefault('input_example', - default_input_example) + mlflow_logging_config.setdefault( + 'input_example', + default_input_example, + ) self.mlflow_logging_config = mlflow_logging_config self.huggingface_folder_name_fstr = os.path.join( - 'huggingface', huggingface_folder_name) + 'huggingface', + huggingface_folder_name, + ) - self.save_interval: Time = Time.from_input(save_interval, - TimeUnit.EPOCH) + self.save_interval: Time = Time.from_input( + save_interval, + TimeUnit.EPOCH, + ) self.check_interval = create_interval_scheduler( - self.save_interval, include_end_of_training=True) + self.save_interval, + include_end_of_training=True, + ) self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( - save_folder, loggers=[]) + save_folder, + loggers=[], + ) if self.remote_ud is not None: self.remote_ud._num_concurrent_uploads = 4 @@ -208,14 +233,16 @@ def __init__( def run_event(self, event: Event, state: State, logger: Logger) -> None: # The interval scheduler handles only returning True for the appropriate events if state.get_elapsed_duration() is not None and self.check_interval( - state, - event) and self.last_checkpoint_batch != state.timestamp.batch: + state, + event, + ) and self.last_checkpoint_batch != state.timestamp.batch: self._save_checkpoint(state, logger) elif event == Event.INIT: if not isinstance(state.model, HuggingFaceModel): raise ValueError( f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. ' - + f'Got {type(state.model)} instead.') + + f'Got {type(state.model)} instead.', + ) if self.remote_ud is not None: self.remote_ud.init(state, logger) state.callbacks.append(self.remote_ud) @@ -230,12 +257,13 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: raise ValueError( f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. ' + - 'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.' + 'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.', ) import mlflow mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( - '1GB') + '1GB', + ) elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. timeout = 3600 @@ -244,7 +272,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: wait_time = time.time() - wait_start if wait_time > timeout: raise TimeoutError( - f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.' + f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.', ) time.sleep(2) @@ -261,7 +289,8 @@ def _is_last_batch(self, state: State): epoch_complete = state.dataloader_len == state.timestamp.batch_in_epoch second_to_last_epoch = state.max_duration.unit == TimeUnit.EPOCH and ( - state.timestamp.epoch == state.max_duration.value - 1) + state.timestamp.epoch == state.max_duration.value - 1 + ) # If the save interval is specified as exactly the same number of batches as the total duration, # but the max duration is specified in epochs, we need a special case to identify we are on the last batch # and should write the mlflow checkpoint. This should occur on the last batch of the final epoch. @@ -273,7 +302,8 @@ def _is_last_batch(self, state: State): if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH: assert state.dataloader_len is not None # for pyright return int(state.timestamp.batch) % math.ceil( - state.max_duration.value * state.dataloader_len) == 0 + state.max_duration.value * state.dataloader_len, + ) == 0 return False @@ -284,7 +314,9 @@ def _all_child_processes_done(self) -> bool: return x.item() == 0 def transform_model_and_tokenizer( - self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: """Transform the model and tokenizer before saving. @@ -315,8 +347,11 @@ def _save_checkpoint(self, state: State, logger: Logger): save_dir = format_name_with_dist_and_time( str( Path(self.save_dir_format_str) / - self.huggingface_folder_name_fstr), state.run_name, - state.timestamp) + self.huggingface_folder_name_fstr, + ), + state.run_name, + state.timestamp, + ) # Use a temporary directory if save_dir is remote. use_temp_dir = self.remote_ud is not None @@ -344,7 +379,9 @@ def _save_checkpoint(self, state: State, logger: Logger): if version.parse(torch.__version__) > version.parse('2.2.9'): from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, get_model_state_dict) + StateDictOptions, + get_model_state_dict, + ) cpu_offload = True # Add a dtensor->cpu tensor hook to avoid CUDA OOM @@ -373,20 +410,26 @@ def dtensor_to_tensor_hook( for _, module in state_dict_model.named_modules(): if isinstance(module, FSDP): hooks.append( - module._register_state_dict_hook( - dtensor_to_tensor_hook)) + module. + _register_state_dict_hook(dtensor_to_tensor_hook), + ) - state_dict = get_model_state_dict(state_dict_model, - options=StateDictOptions( - full_state_dict=True, - cpu_offload=cpu_offload)) + state_dict = get_model_state_dict( + state_dict_model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=cpu_offload, + ), + ) for hook in hooks: hook.remove() else: state_dict_context = fsdp_state_dict_type_context( - original_model, state_dict_type='full') if ( - (not state.is_model_ddp) and isinstance( - state_dict_model, FSDP)) else contextlib.nullcontext() + original_model, + state_dict_type='full', + ) if ((not state.is_model_ddp) and + isinstance(state_dict_model, + FSDP)) else contextlib.nullcontext() with state_dict_context: state_dict = state_dict_model.state_dict() @@ -419,7 +462,8 @@ def dtensor_to_tensor_hook( new_model_instance = type(original_model)( new_base_model_instance, - original_model.peft_config[active_adapter]) + original_model.peft_config[active_adapter], + ) new_model_instance.to(dtype=self.dtype) else: # First create the model instance on meta device to avoid the @@ -427,7 +471,8 @@ def dtensor_to_tensor_hook( with init_empty_weights(): new_model_instance = type(original_model)(copied_config) new_model_instance.generation_config.update( - **original_model.generation_config.to_dict()) + **original_model.generation_config.to_dict(), + ) # Then load the state dict in with "assign" so that the state dict # is loaded properly even though the model is initially on meta device. @@ -436,7 +481,9 @@ def dtensor_to_tensor_hook( # Transform the model and tokenizer before saving new_model_instance, original_tokenizer = self.transform_model_and_tokenizer( - new_model_instance, original_tokenizer) + new_model_instance, + original_tokenizer, + ) log.debug('Saving Hugging Face checkpoint to disk') new_model_instance.save_pretrained(temp_save_dir) @@ -456,9 +503,10 @@ def dtensor_to_tensor_hook( for filename in os.listdir(temp_save_dir): remote_file_name = os.path.join(save_dir, filename) remote_file_uri = self.remote_ud.remote_backend.get_uri( - remote_file_name) + remote_file_name, + ) log.info( - f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' + f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}', ) self.remote_ud.upload_file( state=state, @@ -478,21 +526,22 @@ def dtensor_to_tensor_hook( log.debug('Logging Hugging Face model to MLFlow') for i, mlflow_logger in enumerate(self.mlflow_loggers): log.debug( - f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' + f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}', ) local_save_path = str( - Path(temp_save_dir) / f'mlflow_save_{i}') + Path(temp_save_dir) / f'mlflow_save_{i}', + ) # TODO: Remove after mlflow fixes the bug that makes this necessary import mlflow mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' model_saving_kwargs: Dict[str, Any] = { - 'path': local_save_path + 'path': local_save_path, } if composer_model.using_peft: model_saving_kwargs['flavor'] = 'peft' - model_saving_kwargs[ - 'save_pretrained_dir'] = temp_save_dir + model_saving_kwargs['save_pretrained_dir' + ] = temp_save_dir model_saving_kwargs[ 'metadata'] = self.mlflow_logging_config['metadata'] else: @@ -506,7 +555,10 @@ def dtensor_to_tensor_hook( license_filename = _maybe_get_license_filename( local_save_path, self.mlflow_logging_config['metadata'].get( - 'pretrained_model_name', None)) + 'pretrained_model_name', + None, + ), + ) if license_filename is not None: mlflow_logger._mlflow_client.log_artifact( mlflow_logger._run_id, @@ -527,7 +579,8 @@ def dtensor_to_tensor_hook( self.mlflow_registered_model_name, 'await_creation_for': 3600, - }) + }, + ) process.start() self.child_processes.append(process) diff --git a/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py index 89ee37cf0c..3af44fe4f2 100644 --- a/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py +++ b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py @@ -77,7 +77,7 @@ def fit_start(self, state: State, logger: Logger) -> None: from megablocks.layers.moe import MoE except: raise RuntimeError( - 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.', ) for module in state.model.modules(): if isinstance(module, (MoE, dMoE)): @@ -85,7 +85,7 @@ def fit_start(self, state: State, logger: Logger) -> None: return raise RuntimeError( - f'Callback not initialized correctly; self.topk not instantiated.' + f'Callback not initialized correctly; self.topk not instantiated.', ) def batch_end(self, state: State, logger: Logger) -> None: @@ -94,7 +94,7 @@ def batch_end(self, state: State, logger: Logger) -> None: from megablocks.layers.moe import get_load_balancing_loss except: raise RuntimeError( - 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.', ) tokens_per_expert, _ = zip(*get_load_balancing_loss()) diff --git a/llmfoundry/callbacks/monolithic_ckpt_callback.py b/llmfoundry/callbacks/monolithic_ckpt_callback.py index 395a13111c..05f3b6969b 100644 --- a/llmfoundry/callbacks/monolithic_ckpt_callback.py +++ b/llmfoundry/callbacks/monolithic_ckpt_callback.py @@ -8,12 +8,18 @@ import torch from composer.core import Callback, State -from composer.core.state import (fsdp_get_optim_state_dict, - fsdp_state_dict_type_context) +from composer.core.state import ( + fsdp_get_optim_state_dict, + fsdp_state_dict_type_context, +) from composer.loggers import Logger from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader -from composer.utils import (dist, format_name_with_dist_and_time, parse_uri, - reproducibility) +from composer.utils import ( + dist, + format_name_with_dist_and_time, + parse_uri, + reproducibility, +) __all__ = ['MonolithicCheckpointSaver'] @@ -29,14 +35,17 @@ class MonolithicCheckpointSaver(Callback): keep_optimizers (bool): Whether to save the optimizer state in the monolithic checkpoint. """ - def __init__(self, - save_folder: str, - batch_interval: int, - filename: str = 'ep{epoch}-ba{batch}.pt', - overwrite: bool = False, - keep_optimizers: bool = False): + def __init__( + self, + save_folder: str, + batch_interval: int, + filename: str = 'ep{epoch}-ba{batch}.pt', + overwrite: bool = False, + keep_optimizers: bool = False, + ): self.backend, self.bucket_name, self.save_dir_format_str = parse_uri( - save_folder) + save_folder, + ) self.filename_format_str = filename self.batch_interval = batch_interval self.upload_to_object_store = (self.backend != '') @@ -44,7 +53,8 @@ def __init__(self, self.keep_optimizers = keep_optimizers if self.upload_to_object_store: self.remote_ud = RemoteUploaderDownloader( - bucket_uri=f'{self.backend}://{self.bucket_name}') + bucket_uri=f'{self.backend}://{self.bucket_name}', + ) else: self.remote_ud = None @@ -66,15 +76,20 @@ def fit_end(self, state: State, logger: Logger) -> None: def _save_checkpoint(self, state: State, logger: Logger) -> None: del logger # unused - filename = format_name_with_dist_and_time(self.filename_format_str, - state.run_name, - state.timestamp) - save_dir = format_name_with_dist_and_time(self.save_dir_format_str, - state.run_name, - state.timestamp) + filename = format_name_with_dist_and_time( + self.filename_format_str, + state.run_name, + state.timestamp, + ) + save_dir = format_name_with_dist_and_time( + self.save_dir_format_str, + state.run_name, + state.timestamp, + ) dir_context_mgr = tempfile.TemporaryDirectory( ) if self.upload_to_object_store else contextlib.nullcontext( - enter_result=save_dir) + enter_result=save_dir, + ) with dir_context_mgr as temp_save_dir: # pyright doesn't know about enter_result assert isinstance(temp_save_dir, str) @@ -85,15 +100,17 @@ def _save_checkpoint(self, state: State, logger: Logger) -> None: os.makedirs(dirname, exist_ok=True) state_dict = { 'state': state.state_dict(), - 'rng': reproducibility.get_rng_state() + 'rng': reproducibility.get_rng_state(), } # Remove sharded model and optimizer state dicts state_dict['state'].pop('optimizers') state_dict['state'].pop('model') # Add in unsharded model params. - with fsdp_state_dict_type_context(state.model, - state_dict_type='full'): + with fsdp_state_dict_type_context( + state.model, + state_dict_type='full', + ): state_dict['state']['model'] = state.model.state_dict() # Add in unsharded optimizer state dict. @@ -101,9 +118,11 @@ def _save_checkpoint(self, state: State, logger: Logger) -> None: optimizer = state.optimizers[0] state_dict['state']['optimizers'] = { type(optimizer).__qualname__: - fsdp_get_optim_state_dict(state.model, - optimizer, - state_dict_type='full') + fsdp_get_optim_state_dict( + state.model, + optimizer, + state_dict_type='full', + ), } if dist.get_global_rank() == 0: torch.save(state_dict, save_path) @@ -111,7 +130,9 @@ def _save_checkpoint(self, state: State, logger: Logger) -> None: if self.upload_to_object_store and self.remote_ud is not None and dist.get_global_rank( ) == 0: remote_file_name = str(Path(save_dir) / Path(filename)) - self.remote_ud.upload_file(state=state, - remote_file_name=remote_file_name, - file_path=Path(save_path), - overwrite=self.overwrite) + self.remote_ud.upload_file( + state=state, + remote_file_name=remote_file_name, + file_path=Path(save_path), + overwrite=self.overwrite, + ) diff --git a/llmfoundry/callbacks/resumption_callbacks.py b/llmfoundry/callbacks/resumption_callbacks.py index f910114a88..509d1595bd 100644 --- a/llmfoundry/callbacks/resumption_callbacks.py +++ b/llmfoundry/callbacks/resumption_callbacks.py @@ -47,7 +47,8 @@ def fit_start(self, state: State, logger: Logger) -> None: if 'initial_lr' in group: group['initial_lr'] *= self.lr_scale log.info( - f"Set LR and WD to {group['lr']}, {group['weight_decay']}") + f"Set LR and WD to {group['lr']}, {group['weight_decay']}", + ) for scheduler in state.schedulers: scheduler.base_lrs = [ @@ -74,11 +75,11 @@ def __init__(self, layer_names: List[str]): def fit_start(self, state: State, logger: Logger) -> None: del logger # unused - model_layers = set(name for name, _ in state.model.named_parameters()) + model_layers = {name for name, _ in state.model.named_parameters()} for layer in self.layer_names: if layer not in model_layers: raise Exception( - f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}' + f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}', ) successful_freeze = False @@ -90,4 +91,5 @@ def fit_start(self, state: State, logger: Logger) -> None: if not successful_freeze: raise Exception( - f"Tried to run LayerFreezing but didn't freeze any layers") + f"Tried to run LayerFreezing but didn't freeze any layers", + ) diff --git a/llmfoundry/cli/registry_cli.py b/llmfoundry/cli/registry_cli.py index 03046c2f07..38ada51fd9 100644 --- a/llmfoundry/cli/registry_cli.py +++ b/llmfoundry/cli/registry_cli.py @@ -23,7 +23,7 @@ def _get_registries(group: Optional[str] = None) -> list[TypedRegistry]: if group is not None and group not in registry_attr_names: console.print( - f'Group {group} not found in registry. Run `llmfoundry registry get` to see available groups.' + f'Group {group} not found in registry. Run `llmfoundry registry get` to see available groups.', ) return [] @@ -44,8 +44,11 @@ def get(group: Optional[str] = None): table = Table('Registry', 'Description', 'Options', show_lines=True) for r in available_registries: - table.add_row('.'.join(r.namespace), r.description, - ', '.join(r.get_all())) + table.add_row( + '.'.join(r.namespace), + r.description, + ', '.join(r.get_all()), + ) console.print(table) @@ -66,7 +69,11 @@ def find(group: str, name: str): find_output = r.find(name) table = Table('Module', 'File', 'Line number', 'Docstring') - table.add_row(find_output['module'], find_output['file'], - str(find_output['line_no']), find_output['docstring']) + table.add_row( + find_output['module'], + find_output['file'], + str(find_output['line_no']), + find_output['docstring'], + ) console.print(table) diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py index 027ea7b07a..51635b4586 100644 --- a/llmfoundry/data/__init__.py +++ b/llmfoundry/data/__init__.py @@ -3,15 +3,22 @@ from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator, - StreamingFinetuningDataset, - build_finetuning_dataloader) -from llmfoundry.data.packing import (BinPackCollator, auto_packing_ratio, - profile_packing) -from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, - StreamingTextDataset, - build_text_dataloader, - get_tokens_per_batch_func) +from llmfoundry.data.finetuning import ( + Seq2SeqFinetuningCollator, + StreamingFinetuningDataset, + build_finetuning_dataloader, +) +from llmfoundry.data.packing import ( + BinPackCollator, + auto_packing_ratio, + profile_packing, +) +from llmfoundry.data.text_data import ( + ConcatenatedSequenceCollatorWrapper, + StreamingTextDataset, + build_text_dataloader, + get_tokens_per_batch_func, +) from llmfoundry.registry import dataloaders dataloaders.register('text', func=build_text_dataloader) diff --git a/llmfoundry/data/data.py b/llmfoundry/data/data.py index c7b018c5fb..482c296fa5 100644 --- a/llmfoundry/data/data.py +++ b/llmfoundry/data/data.py @@ -23,8 +23,10 @@ class NoConcatDataset(IterableDataset): Returns dicts of {'text': bytes} """ - def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, - hf_datasets.Dataset]): + def __init__( + self, + hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], + ): self.hf_dataset = hf_dataset def __iter__(self) -> Iterable[Dict[str, bytes]]: @@ -73,44 +75,54 @@ def __init__( self.eos_text = eos_text self.should_wrap = not no_wrap - self.bos_tokens = self.tokenizer(self.bos_text, - truncation=False, - padding=False, - add_special_tokens=False)['input_ids'] + self.bos_tokens = self.tokenizer( + self.bos_text, + truncation=False, + padding=False, + add_special_tokens=False, + )['input_ids'] if len(self.bos_tokens) > 1: warnings.warn( f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token\ - , instead we got {self.bos_tokens}. Quit if this was in error.') + , instead we got {self.bos_tokens}. Quit if this was in error.', + ) - self.eos_tokens = self.tokenizer(self.eos_text, - truncation=False, - padding=False, - add_special_tokens=False)['input_ids'] + self.eos_tokens = self.tokenizer( + self.eos_text, + truncation=False, + padding=False, + add_special_tokens=False, + )['input_ids'] if len(self.eos_tokens) > 1: warnings.warn( f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token\ - , instead we got {self.eos_tokens}. Quit if this was in error.') + , instead we got {self.eos_tokens}. Quit if this was in error.', + ) eos_text_provided = self.eos_text != '' bos_text_provided = self.bos_text != '' test_text = self.tokenizer('') - if len(test_text['input_ids']) > 0 and (eos_text_provided or - bos_text_provided): + if len( + test_text['input_ids'], + ) > 0 and (eos_text_provided or bos_text_provided): message = 'both eos and bos' if eos_text_provided and bos_text_provided else ( - 'eos_text' if eos_text_provided else 'bos_text') + 'eos_text' if eos_text_provided else 'bos_text' + ) warnings.warn( f'The provided tokenizer adds special tokens, but you also specified {message}. This may result ' + - 'in duplicated special tokens. Please be sure this is what you intend.' + 'in duplicated special tokens. Please be sure this is what you intend.', ) def __iter__(self) -> Iterable[Dict[str, bytes]]: buffer = [] for sample in self.hf_dataset: - encoded = self.tokenizer(sample['text'], - truncation=False, - padding=False) + encoded = self.tokenizer( + sample['text'], + truncation=False, + padding=False, + ) iids = encoded['input_ids'] buffer = buffer + self.bos_tokens + iids + self.eos_tokens while len(buffer) >= self.max_length: @@ -118,5 +130,5 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: buffer = buffer[self.max_length:] if self.should_wrap else [] yield { # convert to bytes to store in MDS binary format - 'tokens': np.asarray(concat_sample).tobytes() + 'tokens': np.asarray(concat_sample).tobytes(), } diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index 61471420f8..1a808d1a3e 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -15,8 +15,11 @@ ] -def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int) -> DataSpec: +def build_dataloader( + cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + device_batch_size: int, +) -> DataSpec: """Builds a dataloader from a config. Args: @@ -28,7 +31,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, kwargs = { 'cfg': cfg, 'tokenizer': tokenizer, - 'device_batch_size': device_batch_size + 'device_batch_size': device_batch_size, } return construct_from_registry( diff --git a/llmfoundry/data/finetuning/__init__.py b/llmfoundry/data/finetuning/__init__.py index 3b5c277199..5d891d546c 100644 --- a/llmfoundry/data/finetuning/__init__.py +++ b/llmfoundry/data/finetuning/__init__.py @@ -3,10 +3,12 @@ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader -from llmfoundry.data.finetuning.tasks import (StreamingFinetuningDataset, - dataset_constructor, - is_valid_ift_example, - tokenize_formatted_example) +from llmfoundry.data.finetuning.tasks import ( + StreamingFinetuningDataset, + dataset_constructor, + is_valid_ift_example, + tokenize_formatted_example, +) __all__ = [ 'Seq2SeqFinetuningCollator', diff --git a/llmfoundry/data/finetuning/collator.py b/llmfoundry/data/finetuning/collator.py index 7d592483f1..42af7e9375 100644 --- a/llmfoundry/data/finetuning/collator.py +++ b/llmfoundry/data/finetuning/collator.py @@ -27,18 +27,21 @@ def ensure_list(x: Union[List, torch.Tensor]) -> List: return x -def validate_target_settings(target_prompts: str, target_responses: str, - decoder_only_format: bool): +def validate_target_settings( + target_prompts: str, + target_responses: str, + decoder_only_format: bool, +): """Raises an error if target settings are invalid.""" - if (not decoder_only_format) and (target_prompts != 'none' or - target_responses != 'last'): + if (not decoder_only_format + ) and (target_prompts != 'none' or target_responses != 'last'): raise ValueError( - f'When using encoder_decoder format, you must use target_prompts="none" and target_responses="last".' + f'When using encoder_decoder format, you must use target_prompts="none" and target_responses="last".', ) if target_responses not in {'all', 'last'}: raise ValueError( - f'target_responses must be either "last" or "all" but {target_responses=}' + f'target_responses must be either "last" or "all" but {target_responses=}', ) if target_prompts.startswith('length>='): @@ -47,37 +50,43 @@ def validate_target_settings(target_prompts: str, target_responses: str, raise ValueError( f'target_prompts starts with "length>=" but the rest of the string is not digits ({target_prompts=}). ' +\ 'To use this configuration option, set target_prompts "length>=XX" where "XX" is a positive integer indicating ' +\ - 'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.' + 'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.', ) cutoff = int(cutoff) if cutoff <= 0: raise ValueError( - f'You are trying to set the target_prompts length cutoff to a negative number {cutoff=}. This is not allowed.' + f'You are trying to set the target_prompts length cutoff to a negative number {cutoff=}. This is not allowed.', ) elif target_prompts not in {'all', 'none'}: raise ValueError( - f'target_prompts must either be "all", "none" or "length>=XX" where "XX" is a positive integer, but {target_prompts=}' + f'target_prompts must either be "all", "none" or "length>=XX" where "XX" is a positive integer, but {target_prompts=}', ) ###### Functions to implement target_prompts and target_responses choices ##### -def _sequence_to_labels_all(sequence: list[int], - is_last_turn: bool, - cutoff: Optional[int] = None) -> list[int]: +def _sequence_to_labels_all( + sequence: list[int], + is_last_turn: bool, + cutoff: Optional[int] = None, +) -> list[int]: del is_last_turn, cutoff # unused return sequence -def _sequence_to_labels_none(sequence: list[int], - is_last_turn: bool, - cutoff: Optional[int] = None) -> list[int]: +def _sequence_to_labels_none( + sequence: list[int], + is_last_turn: bool, + cutoff: Optional[int] = None, +) -> list[int]: del is_last_turn, cutoff # unused return [_HF_IGNORE_INDEX] * len(sequence) -def _sequence_to_labels_last(sequence: list[int], - is_last_turn: bool, - cutoff: Optional[int] = None) -> list[int]: +def _sequence_to_labels_last( + sequence: list[int], + is_last_turn: bool, + cutoff: Optional[int] = None, +) -> list[int]: del cutoff # unused if is_last_turn: return sequence @@ -85,9 +94,11 @@ def _sequence_to_labels_last(sequence: list[int], return [_HF_IGNORE_INDEX] * len(sequence) -def _sequence_to_labels_cutoff(sequence: list[int], - is_last_turn: bool, - cutoff: Optional[int] = None) -> list[int]: +def _sequence_to_labels_cutoff( + sequence: list[int], + is_last_turn: bool, + cutoff: Optional[int] = None, +) -> list[int]: del is_last_turn # unused if cutoff is None: raise ValueError('input ``cutoff`` must be provided') @@ -106,18 +117,21 @@ def _sequence_to_labels_cutoff(sequence: list[int], def stitch_turns_decoder_only( - example_turns: list[dict[str, list[int]]], - target_prompts: str, - target_responses: str, - eos_token_id: Optional[int] = None, - validate: bool = False) -> tuple[list[int], list[int]]: + example_turns: list[dict[str, list[int]]], + target_prompts: str, + target_responses: str, + eos_token_id: Optional[int] = None, + validate: bool = False, +) -> tuple[list[int], list[int]]: target_prompts = target_prompts.lower() target_responses = target_responses.lower() if validate: - validate_target_settings(target_prompts, - target_responses, - decoder_only_format=True) + validate_target_settings( + target_prompts, + target_responses, + decoder_only_format=True, + ) if target_prompts.startswith('length'): prompt_cutoff = int(target_prompts.split('>=')[-1]) @@ -148,7 +162,7 @@ def stitch_turns_decoder_only( if len(input_ids) != len(labels): raise ValueError( - f'input_ids and labels should be the same length, {len(input_ids)=}, {len(labels)=}' + f'input_ids and labels should be the same length, {len(input_ids)=}, {len(labels)=}', ) return input_ids, labels @@ -240,29 +254,32 @@ def __init__( 'decoder_input_ids', 'decoder_attention_mask', ] - found_keys = [] - for illegal_key in illegal_keys: - if illegal_key in self.batch_metadata: - found_keys.append(illegal_key) + found_keys = [ + illegal_key for illegal_key in illegal_keys + if illegal_key in self.batch_metadata + ] if found_keys: raise ValueError( f'The following keys are in batch_metadata but are not allowed: {", ".join(found_keys)}.\n' +\ f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' +\ - f'{", ".join(illegal_keys)}' + f'{", ".join(illegal_keys)}', ) if (max_seq_len % 8) != 0: log.warning( - 'For performance, a max_seq_len as a multiple of 8 is recommended.' + 'For performance, a max_seq_len as a multiple of 8 is recommended.', ) if self.tokenizer.pad_token_id is None: raise ValueError( - f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None' + f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None', ) - validate_target_settings(self.target_prompts, self.target_responses, - self.decoder_only_format) + validate_target_settings( + self.target_prompts, + self.target_responses, + self.decoder_only_format, + ) if self.target_prompts.startswith('length'): self.prompt_cutoff = int(self.target_prompts.split('>=')[-1]) self.prompt_to_target = _TARGET_POLICY_LOOKUP['length'] @@ -280,7 +297,7 @@ def __call__(self, for check_key in ['input_ids', 'labels']: if check_key not in examples[0]['turns'][0]: raise KeyError( - f'Examples returned by dataset do not include required key: {check_key}' + f'Examples returned by dataset do not include required key: {check_key}', ) if self.decoder_only_format: @@ -298,7 +315,9 @@ def __call__(self, return batch def _process_and_batch_decoder_only( - self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]: + self, + examples: List[TokenizedExample], + ) -> Dict[str, torch.Tensor]: # Steps explained in comments processed_examples = [] for example in examples: @@ -323,14 +342,14 @@ def _process_and_batch_decoder_only( 'This sample should have been filtered out before reaching the collator. If using ' +\ 'pre-tokenized streaming data, this may have resulted from using different ' +\ '``target_prompts``, ``target_responses``, or ``max_seq_len`` ' +\ - 'settings when preparing the streaming dataset than what are currently being used.' + 'settings when preparing the streaming dataset than what are currently being used.', ) # Still issue a warning when truncating if not self._warned_truncated: warnings.warn( f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' +\ - f'If truncation is a problem, consider increasing max_seq_len.' + f'If truncation is a problem, consider increasing max_seq_len.', ) self._warned_truncated = True @@ -385,7 +404,9 @@ def _process_and_batch_decoder_only( return batch def _process_and_batch_encoder_decoder( - self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]: + self, + examples: List[TokenizedExample], + ) -> Dict[str, torch.Tensor]: # The encoder-decoder case is has some gotchas. # Steps are explained in comments. processed_examples = [] @@ -439,14 +460,18 @@ def _process_and_batch_encoder_decoder( # We're still missing decoder_input_ids and decoder_attention_mask batch['decoder_input_ids'] = torch.cat([ torch.full((len(processed_examples), 1), - self.tokenizer.pad_token_id), batch['labels'][:, :-1] + self.tokenizer.pad_token_id), + batch['labels'][:, :-1], ], dim=1) batch['decoder_input_ids'].masked_fill_( batch['decoder_input_ids'] == _HF_IGNORE_INDEX, - self.tokenizer.pad_token_id) + self.tokenizer.pad_token_id, + ) batch['decoder_attention_mask'] = torch.not_equal( - batch['labels'], _HF_IGNORE_INDEX) + batch['labels'], + _HF_IGNORE_INDEX, + ) # This logic prevents trimming on at least the first batch if not (self._allow_pad_trimming and self._seen_first_batch): diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index e72ee29719..81bdd5fc8f 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -11,15 +11,21 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase -from llmfoundry.data.finetuning.collator import (Seq2SeqFinetuningCollator, - validate_target_settings) -from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH, - SUPPORTED_EXTENSIONS, - dataset_constructor) +from llmfoundry.data.finetuning.collator import ( + Seq2SeqFinetuningCollator, + validate_target_settings, +) +from llmfoundry.data.finetuning.tasks import ( + DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS, + dataset_constructor, +) from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func -from llmfoundry.utils.exceptions import (MissingHuggingFaceURLSplitError, - NotEnoughDatasetSamplesError) +from llmfoundry.utils.exceptions import ( + MissingHuggingFaceURLSplitError, + NotEnoughDatasetSamplesError, +) log = logging.getLogger(__name__) @@ -35,9 +41,11 @@ _DEFAULT_TARGET_PROMPTS = 'none' -def build_finetuning_dataloader(cfg: DictConfig, - tokenizer: PreTrainedTokenizerBase, - device_batch_size: int) -> DataSpec: +def build_finetuning_dataloader( + cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + device_batch_size: int, +) -> DataSpec: """Builds a finetuning dataloader for training or evaluating. The underlying dataset can be built through one of two code paths: @@ -142,12 +150,16 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer.pad_token = tokenizer.eos_token collate_fn, dataloader_batch_size = _build_collate_fn( - cfg, tokenizer, device_batch_size) + cfg, + tokenizer, + device_batch_size, + ) dataset = None # for pyright sampler = None - if cfg.dataset.get('remote') is not None or cfg.dataset.get( - 'streams') is not None: + if cfg.dataset.get( + 'remote', + ) is not None or cfg.dataset.get('streams') is not None: # Build streaming dataloader streams = build_streams(cfg.dataset) dataset = dataset_constructor.build_from_streaming( @@ -189,17 +201,22 @@ def build_finetuning_dataloader(cfg: DictConfig, backend, _, _ = parse_uri(dataset_name_or_path) if backend not in ['', None]: dataset_name_or_path = _download_remote_hf_dataset( - remote_path=dataset_name_or_path, split=split) + remote_path=dataset_name_or_path, + split=split, + ) split = split.replace('-', '_') # Get the preprocessing function. proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn') if isinstance(proto_preprocessing_fn, (dict, DictConfig)): preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict( - dict(proto_preprocessing_fn)) + dict(proto_preprocessing_fn), + ) else: preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( - proto_preprocessing_fn, dataset_name_or_path) + proto_preprocessing_fn, + dataset_name_or_path, + ) # Build dataset from HF. dataset = dataset_constructor.build_from_hf( @@ -209,12 +226,17 @@ def build_finetuning_dataloader(cfg: DictConfig, max_seq_len=cfg.dataset.max_seq_len, preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, - target_prompts=cfg.dataset.get('target_prompts', - _DEFAULT_TARGET_PROMPTS), - target_responses=cfg.dataset.get('target_responses', - _DEFAULT_TARGET_RESPONSES), + target_prompts=cfg.dataset.get( + 'target_prompts', + _DEFAULT_TARGET_PROMPTS, + ), + target_responses=cfg.dataset.get( + 'target_responses', + _DEFAULT_TARGET_RESPONSES, + ), decoder_only_format=cfg.dataset.decoder_only_format, - hf_kwargs=cfg.dataset.get('hf_kwargs', {})) + hf_kwargs=cfg.dataset.get('hf_kwargs', {}), + ) # Ensure dataset is large enough. if cfg.drop_last: @@ -229,11 +251,14 @@ def build_finetuning_dataloader(cfg: DictConfig, dataloader_batch_size=dataloader_batch_size, world_size=world_size, full_dataset_size=full_dataset_size, - minimum_dataset_size=minimum_dataset_size) + minimum_dataset_size=minimum_dataset_size, + ) # Initialize sampler. - sampler = dist.get_sampler(dataset, - drop_last=cfg.drop_last, - shuffle=cfg.dataset.shuffle) + sampler = dist.get_sampler( + dataset, + drop_last=cfg.drop_last, + shuffle=cfg.dataset.shuffle, + ) assert dataset is not None # for pyright dl = DataLoader( @@ -270,16 +295,17 @@ def _validate_config(dataset_cfg: DictConfig) -> None: if dataset_cfg.get('hf_name') is not None: # Using the HuggingFace dataset codepath illegal_keys = ['local', 'remote'] - discovered_illegal_keys = [] - for key in illegal_keys: - if dataset_cfg.get(key) is not None: - discovered_illegal_keys.append('`' + key + '`') + discovered_illegal_keys = [ + '`' + key + '`' + for key in illegal_keys + if dataset_cfg.get(key) is not None + ] if discovered_illegal_keys: raise ValueError( 'The dataset config sets a value for `hf_name` as well as the ' +\ f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\ 'Those keys are used when building from a streaming dataset, but ' +\ - 'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.' + 'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.', ) elif dataset_cfg.get('remote') is not None: # Using the streaming dataset codepath @@ -293,12 +319,12 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'The dataset config sets a value for `remote` as well as the ' +\ f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\ 'Those keys are used when building from a HuggingFace dataset, but ' +\ - 'setting `remote` instructs the dataset to build from a streaming dataset.' + 'setting `remote` instructs the dataset to build from a streaming dataset.', ) if dataset_cfg.get('local') is None: raise ValueError( 'Using a streaming dataset requires setting both `remote` and `local`, ' +\ - 'but dataset.local is None.' + 'but dataset.local is None.', ) elif dataset_cfg.get('streams') is not None: # Using the streaming dataset codepath @@ -312,7 +338,7 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'The dataset config sets a value for `streams` as well as the ' +\ f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\ 'Those keys are used when building from a HuggingFace dataset, but ' +\ - 'setting `streams` instructs the dataset to build from a streaming dataset.' + 'setting `streams` instructs the dataset to build from a streaming dataset.', ) illegal_keys = ['remote', 'local'] discovered_illegal_keys = [] @@ -324,28 +350,34 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'The dataset config sets a value for `streams` as well as the ' +\ f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\ 'Please either use single stream (set remote/local only) ' +\ - 'or put remote/local under streams' + 'or put remote/local under streams', ) else: raise ValueError( 'In the dataset config, you must set `hf_name` to use a HuggingFace ' +\ 'dataset, or set `remote` to use a streaming dataset, or set ' +\ - '`streams` to use multiple streaming datasets, but all were None.' + '`streams` to use multiple streaming datasets, but all were None.', ) if dataset_cfg.get('max_seq_len') is None: raise ValueError( - 'In the dataset config, you must set the `max_seq_len`') + 'In the dataset config, you must set the `max_seq_len`', + ) # Raise an error if the target_prompts + target_responses + decoder_only_format settings # are invalid target_responses = str( - dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES)).lower() + dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES), + ).lower() target_prompts = str( - dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS)).lower() + dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS), + ).lower() decoder_only_format = dataset_cfg.decoder_only_format - validate_target_settings(target_prompts, target_responses, - decoder_only_format) + validate_target_settings( + target_prompts, + target_responses, + decoder_only_format, + ) def _download_remote_hf_dataset(remote_path: str, split: str) -> str: @@ -381,13 +413,19 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: destination = str( os.path.abspath( os.path.join( - finetune_dir, 'data', - f'{hf_formatted_split}-00000-of-00001{extension}'))) + finetune_dir, + 'data', + f'{hf_formatted_split}-00000-of-00001{extension}', + ), + ), + ) # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file signal_file_path = os.path.join( - finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed') + finetune_dir, + f'.node_{dist.get_node_rank()}_local_rank0_completed', + ) if dist.get_local_rank() == 0: try: get_file(path=name, destination=destination, overwrite=True) @@ -400,11 +438,12 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: raise FileNotFoundError( f'Could not find a file with any of ' + \ f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \ - f'at {files_searched}' + f'at {files_searched}', ) from e else: log.debug( - f'Could not find {name}, looking for another extension') + f'Could not find {name}, looking for another extension', + ) continue os.makedirs(os.path.dirname(signal_file_path), exist_ok=True) @@ -426,8 +465,9 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: def _build_collate_fn( - dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int + dataloader_cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + device_batch_size: int, ) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]: dataset_cfg = dataloader_cfg.dataset max_seq_len = dataset_cfg.max_seq_len @@ -436,10 +476,14 @@ def _build_collate_fn( tokenizer=tokenizer, max_seq_len=max_seq_len, decoder_only_format=dataset_cfg.decoder_only_format, - target_responses=dataset_cfg.get('target_responses', - _DEFAULT_TARGET_RESPONSES), - target_prompts=dataset_cfg.get('target_prompts', - _DEFAULT_TARGET_PROMPTS), + target_responses=dataset_cfg.get( + 'target_responses', + _DEFAULT_TARGET_RESPONSES, + ), + target_prompts=dataset_cfg.get( + 'target_prompts', + _DEFAULT_TARGET_PROMPTS, + ), allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False), ) @@ -454,13 +498,17 @@ def _build_collate_fn( return collate_fn, device_batch_size if packing_ratio == 'auto': - packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer, - device_batch_size) + packing_ratio = auto_packing_ratio( + dataloader_cfg, + tokenizer, + device_batch_size, + ) if isinstance(packing_ratio, str): raise ValueError( 'dataset.packing_ratio must be a float or "auto", but it was set to ' - + f'{packing_ratio}.') + + f'{packing_ratio}.', + ) log.info(f'Using packing ratio {packing_ratio}') @@ -471,7 +519,7 @@ def _build_collate_fn( if not dataset_cfg.decoder_only_format: raise NotImplementedError( - 'On-the-fly packing is currently only supported for decoder-only formats.' + 'On-the-fly packing is currently only supported for decoder-only formats.', ) collate_fn = BinPackCollator( @@ -521,7 +569,7 @@ def _build_collate_fn( 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, }) tokenizer_name = 'EleutherAI/gpt-neox-20b' @@ -529,8 +577,11 @@ def _build_collate_fn( tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) device_batch_size = 1 - dataloader = build_finetuning_dataloader(cfg, tokenizer, - device_batch_size).dataloader + dataloader = build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ).dataloader packing = cfg.dataset.get('packing_ratio') is not None @@ -551,31 +602,43 @@ def _build_collate_fn( is_subseq = batch['sequence_id'][j] == subseq print( '\033[93m{}\033[00m\n'.format('INPUT IDS:'), - tokenizer.decode(batch['input_ids'][ - j, - torch.logical_and( - is_subseq, batch['attention_mask'][j] == - 1)], - skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + tokenizer.decode( + batch['input_ids'][ + j, + torch.logical_and( + is_subseq, + batch['attention_mask'][j] == 1, + )], + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ), + ) context = torch.logical_and( batch['attention_mask'][j] == 1, - batch['labels'][j] == _HF_IGNORE_INDEX) + batch['labels'][j] == _HF_IGNORE_INDEX, + ) print( '\033[92m{}\033[00m\n'.format('CONTEXT: '), - tokenizer.decode(batch['input_ids'][ - j, torch.logical_and(is_subseq, context)], - skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + tokenizer.decode( + batch['input_ids'][ + j, torch.logical_and(is_subseq, context)], + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ), + ) print( '\033[91m{}\033[00m\n'.format('TARGET: '), - tokenizer.decode(batch['input_ids'][ - j, - torch.logical_and( - is_subseq, - batch['labels'][j] != _HF_IGNORE_INDEX)], - skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + tokenizer.decode( + batch['input_ids'][ + j, + torch.logical_and( + is_subseq, + batch['labels'][j] != _HF_IGNORE_INDEX, + )], + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ), + ) else: print( '\033[93m{}\033[00m\n'.format('INPUT IDS:'), @@ -583,32 +646,46 @@ def _build_collate_fn( batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + clean_up_tokenization_spaces=True, + ), + ) context = torch.logical_and( batch['attention_mask'][j] == 1, - batch['labels'][j] == _HF_IGNORE_INDEX) + batch['labels'][j] == _HF_IGNORE_INDEX, + ) print( '\033[92m{}\033[00m\n'.format('CONTEXT: '), - tokenizer.decode(batch['input_ids'][j, context], - skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + tokenizer.decode( + batch['input_ids'][j, context], + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ), + ) print( '\033[91m{}\033[00m\n'.format('TARGET: '), - tokenizer.decode(batch['input_ids'][ - j, batch['labels'][j] != _HF_IGNORE_INDEX], - skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + tokenizer.decode( + batch['input_ids'][ + j, batch['labels'][j] != _HF_IGNORE_INDEX], + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ), + ) else: print( '\033[92m{}\033[00m\n'.format('CONTEXT: '), tokenizer.decode( batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + clean_up_tokenization_spaces=True, + ), + ) print( '\033[91m{}\033[00m\n'.format('TARGET: '), - tokenizer.decode(batch['labels'][ - j, batch['decoder_attention_mask'][j] == 1], - skip_special_tokens=False, - clean_up_tokenization_spaces=True)) + tokenizer.decode( + batch['labels'][j, batch['decoder_attention_mask'][j] == + 1], + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ), + ) print(' ') diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 05a01b80c6..4c468a9d2c 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -38,8 +38,18 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, - Tuple, Union, cast) +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, + cast, +) import datasets as hf_datasets import datasets.exceptions as hf_exceptions @@ -49,26 +59,30 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from streaming import Stream, StreamingDataset from transformers import PreTrainedTokenizerBase -from llmfoundry.data.finetuning.collator import (_HF_IGNORE_INDEX, - stitch_turns_decoder_only, - stitch_turns_encoder_decoder) +from llmfoundry.data.finetuning.collator import ( + _HF_IGNORE_INDEX, + stitch_turns_decoder_only, + stitch_turns_encoder_decoder, +) # yapf: disable -from llmfoundry.utils.exceptions import (ALLOWED_MESSAGES_KEYS, - ALLOWED_PROMPT_KEYS, - ALLOWED_RESPONSE_KEYS, - ConsecutiveRepeatedChatRolesError, - IncorrectMessageKeyQuantityError, - InvalidContentTypeError, - InvalidFileExtensionError, - InvalidLastChatMessageRoleError, - InvalidPromptResponseKeysError, - InvalidPromptTypeError, - InvalidResponseTypeError, - InvalidRoleError, - MisconfiguredHfDatasetError, - NotEnoughChatDataError, - UnableToProcessPromptResponseError, - UnknownExampleTypeError) +from llmfoundry.utils.exceptions import ( + ALLOWED_MESSAGES_KEYS, + ALLOWED_PROMPT_KEYS, + ALLOWED_RESPONSE_KEYS, + ConsecutiveRepeatedChatRolesError, + IncorrectMessageKeyQuantityError, + InvalidContentTypeError, + InvalidFileExtensionError, + InvalidLastChatMessageRoleError, + InvalidPromptResponseKeysError, + InvalidPromptTypeError, + InvalidResponseTypeError, + InvalidRoleError, + MisconfiguredHfDatasetError, + NotEnoughChatDataError, + UnableToProcessPromptResponseError, + UnknownExampleTypeError, +) # yapf: enable from llmfoundry.utils.logging_utils import SpecificWarningFilter @@ -86,8 +100,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: _ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'} _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'} DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath( - os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, - '.downloaded_finetuning')) + os.path.join( + os.path.realpath(__file__), + os.pardir, + os.pardir, + os.pardir, + '.downloaded_finetuning', + ), +) SUPPORTED_EXTENSIONS = ['.csv', '.json', '.jsonl', '.parquet'] PromptResponseDict = Mapping[str, str] @@ -111,14 +131,20 @@ def _get_example_type(example: Example) -> ExampleType: """ if not isinstance(example, Mapping): raise TypeError( - f'Expected example to be a Mapping, but found {type(example)}') - if (len(example.keys()) == 1 and - any(allowed_message_key in example - for allowed_message_key in ALLOWED_MESSAGES_KEYS)): + f'Expected example to be a Mapping, but found {type(example)}', + ) + if ( + len(example.keys()) == 1 and any( + allowed_message_key in example + for allowed_message_key in ALLOWED_MESSAGES_KEYS + ) + ): return 'chat' - elif (len(example.keys()) == 2 and - any(p in example for p in ALLOWED_PROMPT_KEYS) and - any(r in example for r in ALLOWED_RESPONSE_KEYS)): + elif ( + len(example.keys()) == 2 and + any(p in example for p in ALLOWED_PROMPT_KEYS) and + any(r in example for r in ALLOWED_RESPONSE_KEYS) + ): return 'prompt_response' else: raise UnknownExampleTypeError(example) @@ -139,7 +165,7 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]): if not isinstance(dictionary, Mapping): raise TypeError( - f'Expected dictionary to be a mapping, but found {type(dictionary)}' + f'Expected dictionary to be a mapping, but found {type(dictionary)}', ) desired_keys = allowed_keys.intersection(dictionary.keys()) return list(desired_keys)[0] @@ -148,11 +174,13 @@ def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]): def _validate_chat_formatted_example(example: ChatFormattedDict): if not isinstance(example, Mapping): raise TypeError( - f'Expected example to be a mapping, but found {type(example)}') + f'Expected example to be a mapping, but found {type(example)}', + ) messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)] if not isinstance(messages, List): raise TypeError( - f'Expected messages to be an iterable, but found {type(messages)}') + f'Expected messages to be an iterable, but found {type(messages)}', + ) if len(messages) <= 1: raise NotEnoughChatDataError() @@ -160,13 +188,17 @@ def _validate_chat_formatted_example(example: ChatFormattedDict): role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS) last_role = last_message[role_key] if last_role not in _ALLOWED_LAST_MESSAGE_ROLES: - raise InvalidLastChatMessageRoleError(last_role, - _ALLOWED_LAST_MESSAGE_ROLES) + raise InvalidLastChatMessageRoleError( + last_role, + _ALLOWED_LAST_MESSAGE_ROLES, + ) last_message_role = None for message in messages: role_key, content_key = _get_key(message, _ALLOWED_ROLE_KEYS), _get_key( - message, _ALLOWED_CONTENT_KEYS) + message, + _ALLOWED_CONTENT_KEYS, + ) if len(message.keys()) != 2: raise IncorrectMessageKeyQuantityError(list(message.keys())) if message[role_key] not in _ALLOWED_ROLES: @@ -174,14 +206,15 @@ def _validate_chat_formatted_example(example: ChatFormattedDict): if not isinstance(message[content_key], str): raise InvalidContentTypeError(type(message[content_key])) if last_message_role is not None and last_message_role == message[ - role_key]: + role_key]: raise ConsecutiveRepeatedChatRolesError(last_message_role) last_message_role = message[role_key] def _slice_chat_formatted_example( - example: ChatFormattedDict, - tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, str]]: + example: ChatFormattedDict, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, str]]: """Slices chat example into a list of templated prompt, response turns. Note: Assistant messages mark the end of chat turns. So there are as many turns as there are @@ -203,31 +236,39 @@ def _slice_chat_formatted_example( last_message = messages[-1] if last_message['role'] != 'assistant': - raise InvalidLastChatMessageRoleError(last_message['role'], - set(['assistant'])) + raise InvalidLastChatMessageRoleError( + last_message['role'], + set('assistant'), + ) def slice_out_last_turn( - messages_through_current_turn: List[Dict[str, str]], - conversation_through_previous_turn: str) -> Tuple[str, str]: + messages_through_current_turn: List[Dict[str, str]], + conversation_through_previous_turn: str, + ) -> Tuple[str, str]: full_conversation = tokenizer.apply_chat_template( - messages_through_current_turn, tokenize=False) + messages_through_current_turn, + tokenize=False, + ) prompt_with_history = tokenizer.apply_chat_template( messages_through_current_turn[:-1], tokenize=False, - add_generation_prompt=True) + add_generation_prompt=True, + ) if conversation_through_previous_turn != full_conversation[:len( - conversation_through_previous_turn)]: + conversation_through_previous_turn, + )]: raise ValueError( - f'The full conversation must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {full_conversation=}' + f'The full conversation must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {full_conversation=}', ) if conversation_through_previous_turn != prompt_with_history[:len( - conversation_through_previous_turn)]: + conversation_through_previous_turn, + )]: raise ValueError( - f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}' + f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}', ) if prompt_with_history != full_conversation[:len(prompt_with_history)]: raise ValueError( - f'prompt_with_history must be the first part of the full conversation. {prompt_with_history=}, {full_conversation=}' + f'prompt_with_history must be the first part of the full conversation. {prompt_with_history=}, {full_conversation=}', ) prompt = prompt_with_history[len(conversation_through_previous_turn):] response = full_conversation[len(prompt_with_history):] @@ -238,7 +279,9 @@ def slice_out_last_turn( for idx, message in enumerate(messages): if message['role'] == 'assistant': prompt, response = slice_out_last_turn( - messages[:idx + 1], conversation_through_previous_turn) + messages[:idx + 1], + conversation_through_previous_turn, + ) templated_prompt_response_turns.append((prompt, response)) conversation_through_previous_turn += prompt conversation_through_previous_turn += response @@ -246,8 +289,11 @@ def slice_out_last_turn( return templated_prompt_response_turns -def _tokenize_with_bos_removal(tokenizer: PreTrainedTokenizerBase, text: str, - text_target: str) -> Dict[str, List[int]]: +def _tokenize_with_bos_removal( + tokenizer: PreTrainedTokenizerBase, + text: str, + text_target: str, +) -> Dict[str, List[int]]: """Tokenizes the prompt and response using the provided tokenizer. Args: @@ -258,23 +304,26 @@ def _tokenize_with_bos_removal(tokenizer: PreTrainedTokenizerBase, text: str, Returns: Dict[str, List[int]]: The tokenized text and text_target. """ - tokenized_sample = tokenizer(text=text, - text_target=text_target, - padding=False, - truncation=False) + tokenized_sample = tokenizer( + text=text, + text_target=text_target, + padding=False, + truncation=False, + ) # Remove the BOS token from the start of the labels if it was automatically added if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token: if tokenizer.bos_token_id is not None and tokenized_sample['labels'][ - 0] == tokenizer.bos_token_id: + 0] == tokenizer.bos_token_id: tokenized_sample['labels'] = tokenized_sample['labels'][1:] return tokenized_sample def _tokenize_chat_formatted_example( - example: ChatFormattedDict, - tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: + example: ChatFormattedDict, + tokenizer: PreTrainedTokenizerBase, +) -> TokenizedExample: """Tokenizes a chat-formatted example using the provided tokenizer. Args: @@ -292,20 +341,22 @@ def _tokenize_chat_formatted_example( # be able to assume that none of the tokens are pad tokens. return { 'turns': [ - tokenizer(text=prompt, - text_target=response, - add_special_tokens=False, - padding=False, - truncation=False) - for prompt, response in _slice_chat_formatted_example( - example, tokenizer) - ] + tokenizer( + text=prompt, + text_target=response, + add_special_tokens=False, + padding=False, + truncation=False, + ) for prompt, response in + _slice_chat_formatted_example(example, tokenizer) + ], } def _tokenize_prompt_response_formatted_example( - example: PromptResponseDict, - tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: + example: PromptResponseDict, + tokenizer: PreTrainedTokenizerBase, +) -> TokenizedExample: """Tokenize a formatted example and validate expected keys.""" example_keys = set(example.keys()) prompt_keys = example_keys.intersection(ALLOWED_PROMPT_KEYS) @@ -333,14 +384,15 @@ def _tokenize_prompt_response_formatted_example( tokenizer=tokenizer, text=prompt, text_target=response, - ) - ] + ), + ], } def tokenize_formatted_example( - example: Example, - tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: + example: Example, + tokenizer: PreTrainedTokenizerBase, +) -> TokenizedExample: """Tokenizes a formatted example using the provided tokenizer. Args: @@ -360,16 +412,24 @@ def tokenize_formatted_example( return _tokenize_chat_formatted_example(chat_example, tokenizer) elif example_format == 'prompt_response': prompt_response_example: PromptResponseDict = cast( - PromptResponseDict, example) + PromptResponseDict, + example, + ) return _tokenize_prompt_response_formatted_example( - prompt_response_example, tokenizer) + prompt_response_example, + tokenizer, + ) else: raise NotImplementedError -def is_valid_ift_example(max_seq_len: int, target_prompts: str, - target_responses: str, decoder_only_format: bool, - example: TokenizedExample) -> bool: +def is_valid_ift_example( + max_seq_len: int, + target_prompts: str, + target_responses: str, + decoder_only_format: bool, + example: TokenizedExample, +) -> bool: """Check if the example is a valid ift example. This function confirms that none of the ``input_ids`` and ``labels`` fields @@ -412,7 +472,8 @@ def is_valid_ift_example(max_seq_len: int, target_prompts: str, else: input_ids, labels = stitch_turns_encoder_decoder( - example_turns=example['turns'],) + example_turns=example['turns'], + ) input_ids = input_ids[:max_seq_len] labels = labels[:max_seq_len] @@ -425,14 +486,18 @@ def is_valid_ift_example(max_seq_len: int, target_prompts: str, return True -def _stream_remote_local_validate(remote: Optional[str], local: Optional[str], - split: Optional[str]): +def _stream_remote_local_validate( + remote: Optional[str], + local: Optional[str], + split: Optional[str], +): if remote is None or (local == remote): if local is not None and os.path.isdir(local): contents = set(os.listdir(local)) if split is not None and split not in contents: raise ValueError( - f'Local directory {local} does not contain split {split}') + f'Local directory {local} does not contain split {split}', + ) class StreamingFinetuningDataset(StreamingDataset): @@ -498,45 +563,50 @@ class StreamingFinetuningDataset(StreamingDataset): devices need to see the same partition of the dataset. Defaults to ``None``. """ - def __init__(self, - tokenizer: PreTrainedTokenizerBase, - streams: Optional[Sequence[Stream]] = None, - local: Optional[str] = None, - remote: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[Union[int, str]] = None, - predownload: Optional[int] = None, - cache_limit: Optional[Union[int, str]] = None, - partition_algo: str = 'relaxed', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1e', - shuffle_seed: int = 9176, - shuffle_block_size: Optional[int] = None, - sampling_method: str = 'balanced', - sampling_granularity: int = 1, - batching_method: str = 'random', - max_seq_len: int = 2048, - allow_unsafe_types: bool = False, - replication: Optional[int] = None, - **kwargs: Any): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + streams: Optional[Sequence[Stream]] = None, + local: Optional[str] = None, + remote: Optional[str] = None, + split: Optional[str] = None, + download_retry: int = 2, + download_timeout: float = 60, + validate_hash: Optional[str] = None, + keep_zip: bool = False, + epoch_size: Optional[Union[int, str]] = None, + predownload: Optional[int] = None, + cache_limit: Optional[Union[int, str]] = None, + partition_algo: str = 'relaxed', + num_canonical_nodes: Optional[int] = None, + batch_size: Optional[int] = None, + shuffle: bool = False, + shuffle_algo: str = 'py1e', + shuffle_seed: int = 9176, + shuffle_block_size: Optional[int] = None, + sampling_method: str = 'balanced', + sampling_granularity: int = 1, + batching_method: str = 'random', + max_seq_len: int = 2048, + allow_unsafe_types: bool = False, + replication: Optional[int] = None, + **kwargs: Any, + ): if len(kwargs) > 0: raise ValueError( - f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}' + f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}', ) if streams is None: _stream_remote_local_validate(remote, local, split) else: for stream in streams: - _stream_remote_local_validate(stream.remote, stream.local, - split) + _stream_remote_local_validate( + stream.remote, + stream.local, + split, + ) super().__init__( streams=streams, @@ -578,18 +648,20 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: if isinstance(sample['input_ids'], bytes): sample['input_ids'] = np.frombuffer( sample['input_ids'], - dtype=np.int64)[:self.max_seq_len].tolist().copy() + dtype=np.int64, + )[:self.max_seq_len].tolist().copy() sample['labels'] = np.frombuffer( sample['labels'], - dtype=np.int64)[:self.max_seq_len].tolist().copy() + dtype=np.int64, + )[:self.max_seq_len].tolist().copy() elif isinstance(sample['input_ids'], np.ndarray): - sample['input_ids'] = sample[ - 'input_ids'][:self.max_seq_len].tolist().copy() + sample['input_ids'] = sample['input_ids'][:self.max_seq_len + ].tolist().copy() sample['labels'] = sample['labels'][:self.max_seq_len].tolist( ).copy() else: raise ValueError( - f'Expect input_ids to be bytes or numpy.ndarray type, but got {type(sample["input_ids"])}' + f'Expect input_ids to be bytes or numpy.ndarray type, but got {type(sample["input_ids"])}', ) # Convert to latest format by wrapping sample as a "turn" return {'turns': [sample]} @@ -607,7 +679,7 @@ def register(self, *names: str) -> Callable[[Callable], Callable]: def _register_func(name: str, func: Callable) -> None: if name in self._task_preprocessing_registry: raise ValueError( - f'A tokenization function has already been registered with {name=}.' + f'A tokenization function has already been registered with {name=}.', ) self._task_preprocessing_registry[name] = func return @@ -624,8 +696,9 @@ def print_registered_tasks(self) -> None: log.info('\n'.join(tasks)) def get_preprocessing_fn_from_dict( - self, mapping: Dict[str, - str]) -> Callable[[Dict[str, Any]], Example]: + self, + mapping: Dict[str, str], + ) -> Callable[[Dict[str, Any]], Example]: """Get a preprocessing function from a dictionary. The dictionary maps column names in the dataset to "prompt" and "response". @@ -652,7 +725,7 @@ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]: raise InvalidPromptResponseKeysError(mapping, example) return { 'prompt': example[mapping['prompt']], - 'response': example[mapping['response']] + 'response': example[mapping['response']], } return _preprocessor @@ -660,7 +733,7 @@ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]: def get_preprocessing_fn_from_str( self, preprocessor: Optional[str], - dataset_name: Optional[str] = None + dataset_name: Optional[str] = None, ) -> Optional[Callable[[Dict[str, Any]], Example]]: """Get a preprocessing function from a string. @@ -681,7 +754,7 @@ def get_preprocessing_fn_from_str( return None if dataset_name in self._task_preprocessing_registry: log.info( - f'Re-formatting dataset with "{dataset_name}" preprocessing function.' + f'Re-formatting dataset with "{dataset_name}" preprocessing function.', ) return self._task_preprocessing_registry[dataset_name] else: @@ -692,7 +765,7 @@ def get_preprocessing_fn_from_str( return None if preprocessor in self._task_preprocessing_registry: log.info( - f'Re-formatting dataset with "{preprocessor}" preprocessing function.' + f'Re-formatting dataset with "{preprocessor}" preprocessing function.', ) return self._task_preprocessing_registry[preprocessor] @@ -702,17 +775,23 @@ def get_preprocessing_fn_from_str( preprocessing_fn = getattr(module, function_name) except Exception as e: raise ValueError( - f'Failed to import preprocessing function from string = {preprocessor}.' + f'Failed to import preprocessing function from string = {preprocessor}.', ) from e return preprocessing_fn def build_from_hf( - self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int, + self, + dataset_name: str, + split: str, + safe_load: bool, + max_seq_len: int, preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]], - tokenizer: PreTrainedTokenizerBase, target_prompts: str, - target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str, - Any] + tokenizer: PreTrainedTokenizerBase, + target_prompts: str, + target_responses: str, + decoder_only_format: bool, + hf_kwargs: Dict[str, Any], ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]: """Load a HuggingFace Datasets, preprocess, and tokenize. @@ -738,9 +817,10 @@ def build_from_hf( pass hf_tokenization_logger = logging.getLogger( - 'transformers.tokenization_utils_base') + 'transformers.tokenization_utils_base', + ) sequence_length_warning_filter = SpecificWarningFilter( - 'Token indices sequence length is longer than the specified maximum sequence length' + 'Token indices sequence length is longer than the specified maximum sequence length', ) # We will trim examples later in the collate_fn, so we want to silence this warning from Hugging Face @@ -753,7 +833,9 @@ def build_from_hf( if not os.path.isdir(dataset_name): # dataset_name is not a local dir path, download if needed. local_dataset_dir = os.path.join( - DOWNLOADED_FT_DATASETS_DIRPATH, dataset_name) + DOWNLOADED_FT_DATASETS_DIRPATH, + dataset_name, + ) if _is_empty_or_nonexistent(dirpath=local_dataset_dir): # Safely load a dataset from HF Hub with restricted file types. @@ -766,10 +848,13 @@ def build_from_hf( token=hf_kwargs.get('token', None), revision=hf_kwargs.get('revision', None), local_dir_use_symlinks=False, - local_dir=local_dataset_dir) + local_dir=local_dataset_dir, + ) if _is_empty_or_nonexistent(dirpath=local_dataset_dir): raise InvalidFileExtensionError( - dataset_name, SUPPORTED_EXTENSIONS) + dataset_name, + SUPPORTED_EXTENSIONS, + ) # Set dataset_name to the downloaded location. dataset_name = local_dataset_dir @@ -781,19 +866,26 @@ def build_from_hf( f for _, _, files in os.walk(dataset_name) for f in files ] if not all( - Path(f).suffix in SUPPORTED_EXTENSIONS - for f in dataset_files): - raise InvalidFileExtensionError(dataset_name, - SUPPORTED_EXTENSIONS) - - dataset = hf_datasets.load_dataset(dataset_name, - split=split, - **hf_kwargs) + Path(f).suffix in SUPPORTED_EXTENSIONS + for f in dataset_files + ): + raise InvalidFileExtensionError( + dataset_name, + SUPPORTED_EXTENSIONS, + ) + + dataset = hf_datasets.load_dataset( + dataset_name, + split=split, + **hf_kwargs, + ) def dataset_mapper(example: Dict): if preprocessing_fn is not None: - return tokenize_formatted_example(preprocessing_fn(example), - tokenizer) + return tokenize_formatted_example( + preprocessing_fn(example), + tokenizer, + ) return tokenize_formatted_example(example, tokenizer) detected_cpu_count = os.cpu_count() or 1 @@ -810,8 +902,13 @@ def dataset_mapper(example: Dict): ) filtered_dataset = tokenized_dataset.filter( - partial(is_valid_ift_example, max_seq_len, target_prompts, - target_responses, decoder_only_format), + partial( + is_valid_ift_example, + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + ), num_proc=num_cpus_to_use, desc='Filtering out long prompts', ) @@ -821,7 +918,7 @@ def dataset_mapper(example: Dict): warnings.warn( f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, ' + - 'the prompt or response was empty, or the response was all padding tokens.' + 'the prompt or response was empty, or the response was all padding tokens.', ) except Exception as e: error = e @@ -840,8 +937,10 @@ def dataset_mapper(example: Dict): if isinstance(error, hf_exceptions.DatasetGenerationError): log.error('Huggingface DatasetGenerationError during data prep.') - raise MisconfiguredHfDatasetError(dataset_name=dataset_name, - split=split) + raise MisconfiguredHfDatasetError( + dataset_name=dataset_name, + split=split, + ) if error is not None: log.error('Error during data prep') raise error @@ -852,8 +951,11 @@ def dataset_mapper(example: Dict): assert filtered_dataset is not None return filtered_dataset - def build_from_streaming(self, *args: Any, - **kwargs: Any) -> StreamingFinetuningDataset: + def build_from_streaming( + self, + *args: Any, + **kwargs: Any, + ) -> StreamingFinetuningDataset: return StreamingFinetuningDataset(*args, **kwargs) @@ -906,8 +1008,9 @@ def muennighoff_tokenize_function(inp: Dict) -> PromptResponseDict: response: str = inp['targets'] # Put a space before the response if needed transitions = (' ', '\n', '\t') - if not (prompt.endswith(transitions) or - response.startswith(transitions)): + if not ( + prompt.endswith(transitions) or response.startswith(transitions) + ): response = ' ' + response except Exception as e: raise UnableToProcessPromptResponseError(inp) from e diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 3d525def47..0340114008 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -23,13 +23,15 @@ class BinPackCollator: """Utility collator for packing to reduce padding.""" - def __init__(self, - collator: Callable, - target_batch_size: int, - max_seq_len: int, - pad_token_id: int, - padding_side: Literal['left', 'right'], - max_leftover_bins_to_keep: Optional[int] = None): + def __init__( + self, + collator: Callable, + target_batch_size: int, + max_seq_len: int, + pad_token_id: int, + padding_side: Literal['left', 'right'], + max_leftover_bins_to_keep: Optional[int] = None, + ): self.base_collator = collator self.out_size = int(target_batch_size) self.max_seq_len = int(max_seq_len) @@ -45,7 +47,8 @@ def __init__(self, if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0: raise ValueError( - f'{max_leftover_bins_to_keep=} must be >=0 or None.') + f'{max_leftover_bins_to_keep=} must be >=0 or None.', + ) self.max_leftover_bins_to_keep = max_leftover_bins_to_keep self.n_packed_tokens = 0 @@ -60,12 +63,14 @@ def waste(self) -> float: @property def efficiency(self) -> float: - return self.n_packed_tokens / (self.max_seq_len * - self.n_packed_examples) + return self.n_packed_tokens / ( + self.max_seq_len * self.n_packed_examples + ) def __call__( - self, - examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + self, + examples: List[Dict[str, torch.Tensor]], + ) -> Dict[str, torch.Tensor]: batch = self.base_collator(examples) return self.pack(batch) @@ -84,9 +89,11 @@ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: sizes, trimmed_examples = _trim_batch(batch) return self._pack_trimmed_examples(trimmed_examples, sizes) - def _pack_trimmed_examples(self, trimmed_examples: List[Dict[str, - torch.Tensor]], - sizes: List[int]) -> Dict[str, torch.Tensor]: + def _pack_trimmed_examples( + self, + trimmed_examples: List[Dict[str, torch.Tensor]], + sizes: List[int], + ) -> Dict[str, torch.Tensor]: """Packs trimmed examples into fixed-size bins and repads them. Args: @@ -110,15 +117,17 @@ def _pack_trimmed_examples(self, trimmed_examples: List[Dict[str, self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep] # Re-pad to max_seq_len and batch - batch = _repad(packed_examples, - max_seq_len=self.max_seq_len, - pad_token_id=self.pad_token_id, - padding_side=self.padding_side) + batch = _repad( + packed_examples, + max_seq_len=self.max_seq_len, + pad_token_id=self.pad_token_id, + padding_side=self.padding_side, + ) return batch def _trim_batch( - batch: Dict[str, torch.Tensor] + batch: Dict[str, torch.Tensor], ) -> Tuple[List[int], List[Dict[str, torch.Tensor]]]: """Trims padding off all examples in batch. @@ -150,8 +159,9 @@ def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor], def _combine_in_place( - example: Dict[str, torch.Tensor], - add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + example: Dict[str, torch.Tensor], + add_on: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: if 'labels' in add_on: # Prevents the last token in example from being trained to # predict the first token in add_on, which would make no sense. @@ -159,30 +169,35 @@ def _combine_in_place( for k in example.keys(): if k == 'sequence_id': - example[k] = torch.cat( - [example[k], add_on[k] + 1 + torch.max(example[k])]) + example[k] = torch.cat([ + example[k], + add_on[k] + 1 + torch.max(example[k]), + ]) else: example[k] = torch.cat([example[k], add_on[k]]) return example def _first_fit_bin_packing( - sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int, - max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]] + sizes: List[int], + examples: List[Dict[str, torch.Tensor]], + num_bins: int, + max_bin_size: int, + existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]], ) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[ - str, torch.Tensor]]]]: + str, torch.Tensor]]]]: # Will contain tuples (bin_size_size, packed_example) bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins]) - sizes_and_examples = [ - (size, example) for size, example in zip(sizes, examples) - ] - sorted_sizes_and_examples = sorted(sizes_and_examples, - key=lambda x: x[0], - reverse=True) + sizes_and_examples = list(zip(sizes, examples)) + sorted_sizes_and_examples = sorted( + sizes_and_examples, + key=lambda x: x[0], + reverse=True, + ) required_num_examples = max(0, num_bins - len(bins)) num_examples = len(sizes) @@ -196,7 +211,7 @@ def _first_fit_bin_packing( total_example_sizes = sum(sizes) if total_new_bin_sizes != total_example_sizes: raise AssertionError( - f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.' + f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.', ) sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) @@ -211,7 +226,8 @@ def _first_fit_bin_packing( # - the total size of all new examples # - leftover bins return packed_examples[:num_bins], sum( - bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:] + bin_sizes[:num_bins], + ), sum(sizes), sorted_bins[num_bins:] # Go through each item from longest to shortest. # Note: all items will either go into an existing or new bin. @@ -244,7 +260,7 @@ def _first_fit_bin_packing( total_example_sizes = sum(sizes) if total_new_bin_sizes != total_example_sizes: raise AssertionError( - f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.' + f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.', ) sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) @@ -259,11 +275,16 @@ def _first_fit_bin_packing( # - the total size of all new examples # - leftover bins return packed_examples[:num_bins], sum( - bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:] + bin_sizes[:num_bins], + ), sum(sizes), sorted_bins[num_bins:] -def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, - pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]: +def _repad( + packed_examples: List[Dict[str, torch.Tensor]], + max_seq_len: int, + pad_token_id: int, + padding_side: str, +) -> Dict[str, torch.Tensor]: def pad_tensor(tensor: torch.Tensor, pad_value: int): if len(tensor) == max_seq_len: @@ -296,10 +317,12 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int): return batch -def auto_packing_ratio(dataloader_cfg: DictConfig, - tokenizer: PreTrainedTokenizerBase, - device_batch_size: int, - num_packing_ratios: int = 20) -> float: +def auto_packing_ratio( + dataloader_cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + device_batch_size: int, + num_packing_ratios: int = 20, +) -> float: """Find a packing ratio that minimizes padding with zero waste. By packing examples, we can increase training efficiency, training on more data with less batches. @@ -336,9 +359,14 @@ def auto_packing_ratio(dataloader_cfg: DictConfig, min_ratio = 1 max_ratio = max_seq_len / 100 - profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, - max_ratio, num_packing_ratios, - device_batch_size) + profiling_results = profile_packing( + dataloader_cfg, + tokenizer, + min_ratio, + max_ratio, + num_packing_ratios, + device_batch_size, + ) # Obtain the maximum packing_ratio/minimum padding that has no waste. # profiling_results are sorted from smallest to largest packing_ratio. @@ -352,7 +380,8 @@ def auto_packing_ratio(dataloader_cfg: DictConfig, if dist.is_available() and dist.is_initialized(): device = get_device(None) packing_ratio_tensor = device.tensor_to_device( - torch.tensor(packing_ratio)) + torch.tensor(packing_ratio), + ) dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN') packing_ratio = packing_ratio_tensor.item() @@ -363,9 +392,12 @@ def auto_packing_ratio(dataloader_cfg: DictConfig, def profile_packing( - dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - min_ratio: float, max_ratio: float, num_packing_ratios: int, - device_batch_size: int + dataloader_cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + min_ratio: float, + max_ratio: float, + num_packing_ratios: int, + device_batch_size: int, ) -> Iterable[Tuple[float, Optional[float], Optional[float]]]: """Generator function that profiles example packing across packing ratios. @@ -385,8 +417,10 @@ def profile_packing( from llmfoundry.data.dataloader import build_dataloader max_seq_len = dataloader_cfg.dataset.get('max_seq_len') - max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', - None) + max_leftovers_to_keep = dataloader_cfg.dataset.get( + 'max_leftovers_to_keep', + None, + ) # Turn off packing for the dataloader (we want raw, pre-packed examples) dataloader_cfg = copy.deepcopy(dataloader_cfg) @@ -413,10 +447,12 @@ def profile_packing( # Determine the packing_ratio values we'll try packing_ratios, raw_batch_sizes = [], [] - for packing_ratio in np.linspace(min_ratio, - max_ratio, - num_packing_ratios, - endpoint=True): + for packing_ratio in np.linspace( + min_ratio, + max_ratio, + num_packing_ratios, + endpoint=True, + ): packing_ratio = np.round(10 * packing_ratio) / 10 raw_batch_size = int(packing_ratio * device_batch_size) if raw_batch_size not in raw_batch_sizes: @@ -425,8 +461,11 @@ def profile_packing( n_profile_examples = max(raw_batch_sizes) * 100 - train_dataspec = build_dataloader(dataloader_cfg, tokenizer, - n_profile_examples) + train_dataspec = build_dataloader( + dataloader_cfg, + tokenizer, + n_profile_examples, + ) train_dataloader = train_dataspec.dataloader # Get a bunch of raw examples @@ -446,19 +485,22 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]: max_seq_len=max_seq_len, pad_token_id=0, # <-- Doesn't need to be correct for profiling padding_side='left', # <-- Doesn't need to be correct for profiling - max_leftover_bins_to_keep=max_leftovers_to_keep) + max_leftover_bins_to_keep=max_leftovers_to_keep, + ) # Simulate feeding the packing collator a bunch of data for idx in range(0, len(trimmed_examples_copy), raw_batch_size): batch = trimmed_examples_copy[idx:idx + raw_batch_size] if len(batch) < device_batch_size: continue - packer._pack_trimmed_examples(batch, - sizes[idx:idx + raw_batch_size]) + packer._pack_trimmed_examples( + batch, + sizes[idx:idx + raw_batch_size], + ) if packer.n_packed_examples == 0: log.debug( - 'No examples packed during profiling. Dataset is smaller than device batch size.' + 'No examples packed during profiling. Dataset is smaller than device batch size.', ) return None, None @@ -472,7 +514,7 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]: for i, (packing_ratio, raw_batch_size) in enumerate(zip(packing_ratios, raw_batch_sizes)): log.debug( - f'Progress [{i}/{total_packing_ratios}]: Profiling packing ratio {packing_ratio}' + f'Progress [{i}/{total_packing_ratios}]: Profiling packing ratio {packing_ratio}', ) padding, waste = profile(raw_batch_size) yield (packing_ratio, padding, waste) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index a59098323b..bbb5ae6a15 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -6,8 +6,17 @@ import logging import os from itertools import islice -from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, - Union, cast) +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Union, + cast, +) import numpy as np import torch @@ -98,37 +107,39 @@ class StreamingTextDataset(StreamingDataset): devices need to see the same partition of the dataset. Defaults to ``None``. """ - def __init__(self, - tokenizer: PreTrainedTokenizerBase, - max_seq_len: int, - streams: Optional[Sequence[Stream]] = None, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[Union[int, str]] = None, - predownload: Optional[int] = None, - cache_limit: Optional[Union[int, str]] = None, - partition_algo: str = 'relaxed', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1e', - shuffle_seed: int = 9176, - shuffle_block_size: Optional[int] = None, - sampling_method: str = 'balanced', - sampling_granularity: int = 1, - batching_method: str = 'random', - allow_unsafe_types: bool = False, - replication: Optional[int] = None, - **kwargs: Any): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + max_seq_len: int, + streams: Optional[Sequence[Stream]] = None, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + download_retry: int = 2, + download_timeout: float = 60, + validate_hash: Optional[str] = None, + keep_zip: bool = False, + epoch_size: Optional[Union[int, str]] = None, + predownload: Optional[int] = None, + cache_limit: Optional[Union[int, str]] = None, + partition_algo: str = 'relaxed', + num_canonical_nodes: Optional[int] = None, + batch_size: Optional[int] = None, + shuffle: bool = False, + shuffle_algo: str = 'py1e', + shuffle_seed: int = 9176, + shuffle_block_size: Optional[int] = None, + sampling_method: str = 'balanced', + sampling_granularity: int = 1, + batching_method: str = 'random', + allow_unsafe_types: bool = False, + replication: Optional[int] = None, + **kwargs: Any, + ): if len(kwargs) > 0: raise ValueError( - f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}' + f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}', ) if local is not None and (remote is None or (local == remote)): @@ -136,7 +147,7 @@ def __init__(self, contents = set(os.listdir(local)) if split not in contents: raise ValueError( - f'local directory {local} does not contain split {split}' + f'local directory {local} does not contain split {split}', ) # TODO: discover where yamls are being converted incorrect, but temporary workaround @@ -177,18 +188,24 @@ def _tokenize(self, text_sample: Mapping) -> Dict[str, List[int]]: if self.tokenizer._pad_token is None: # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs raise RuntimeError( - 'If tokenizing on-the-fly, tokenizer must have a pad_token_id') + 'If tokenizing on-the-fly, tokenizer must have a pad_token_id', + ) - return self.tokenizer(text_sample['text'], - truncation=True, - padding='max_length', - max_length=self.max_seq_len) + return self.tokenizer( + text_sample['text'], + truncation=True, + padding='max_length', + max_length=self.max_seq_len, + ) - def _read_binary_tokenized_sample(self, sample: Dict[str, - Any]) -> torch.Tensor: + def _read_binary_tokenized_sample( + self, + sample: Dict[str, Any], + ) -> torch.Tensor: return torch.from_numpy( np.frombuffer(sample['tokens'], - dtype=np.int64)[:self.max_seq_len].copy()) + dtype=np.int64)[:self.max_seq_len].copy(), + ) # How to process a sample def __getitem__(self, @@ -200,7 +217,7 @@ def __getitem__(self, token_sample = self._read_binary_tokenized_sample(sample) else: raise RuntimeError( - 'StreamingTextDataset needs samples to have a `text` or `tokens` column' + 'StreamingTextDataset needs samples to have a `text` or `tokens` column', ) return token_sample @@ -217,13 +234,13 @@ def __init__( self.base_collator = base_collator if (eos_token_id is None) and (bos_token_id is None): raise ValueError( - 'Must supply a value for either eos_token_id or bos_token_id, but got None for both.' + 'Must supply a value for either eos_token_id or bos_token_id, but got None for both.', ) if (eos_token_id is not None) and (bos_token_id is not None): raise ValueError( 'Cannot use *both* EOS and BOS tokens for detecting sequence boundaries. ' +\ 'Please supply `eos_token_id` if sequences end with an EOS token, or use ' +\ - '`bos_token_id` if sequences start with a BOS token.' + '`bos_token_id` if sequences start with a BOS token.', ) if eos_token_id is None: @@ -239,7 +256,9 @@ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]: return batch def get_sequence_id_from_batch( - self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + self, + batch: Dict[str, torch.Tensor], + ) -> torch.Tensor: is_separator = torch.eq(batch['input_ids'], self.split_token_id) cumulative_sep = torch.cumsum(is_separator, dim=1).to(batch['input_ids'].dtype) @@ -258,7 +277,7 @@ def build_streams(dataset_cfg: DictConfig): streams = None if streams_dict is not None: streams = [] - for _, stream in streams_dict.items(): + for stream in streams_dict.values(): # stream is the streams kwargs # fwd all kwargs with **stream allows streaming to check args streams.append(Stream(**stream)) @@ -277,10 +296,12 @@ def build_text_dataloader( eos_token_id = cfg.dataset.pop('eos_token_id', None) bos_token_id = cfg.dataset.pop('bos_token_id', None) - if eos_token_id is None and bos_token_id is None and (hasattr( - tokenizer, 'eos_token_id') or hasattr(tokenizer, 'bos_token_id')): + if eos_token_id is None and bos_token_id is None and ( + hasattr(tokenizer, 'eos_token_id') or + hasattr(tokenizer, 'bos_token_id') + ): log.warning( - 'The user has not provided an eos_token_id or bos_token_id, but the tokenizer has an eos_token_id or a bos_token_id.' + 'The user has not provided an eos_token_id or bos_token_id, but the tokenizer has an eos_token_id or a bos_token_id.', ) tokenizer_eos_token_id = getattr(tokenizer, 'eos_token_id', None) @@ -291,7 +312,7 @@ def build_text_dataloader( else: raise ValueError( eos_mismatch_str + - ' To override this error, set the override_eos_token_id_mismatch_error flag to True in the dataset config section of the YAML.' + ' To override this error, set the override_eos_token_id_mismatch_error flag to True in the dataset config section of the YAML.', ) tokenizer_bos_token_id = getattr(tokenizer, 'bos_token_id', None) @@ -302,7 +323,7 @@ def build_text_dataloader( else: raise ValueError( bos_mismatch_str + - ' To override this error, set the override_bos_token_id_mismatch_error flag to True in the dataset config section of the YAML.' + ' To override this error, set the override_bos_token_id_mismatch_error flag to True in the dataset config section of the YAML.', ) streams = build_streams(cfg.dataset) @@ -318,14 +339,16 @@ def build_text_dataloader( collate_fn = transformers.DataCollatorForLanguageModeling( tokenizer=dataset.tokenizer, mlm=mlm_probability is not None, - mlm_probability=mlm_probability) + mlm_probability=mlm_probability, + ) if (eos_token_id is not None) or (bos_token_id is not None): # Note: Will raise an error if both are non-None collate_fn = ConcatenatedSequenceCollatorWrapper( base_collator=collate_fn, eos_token_id=eos_token_id, - bos_token_id=bos_token_id) + bos_token_id=bos_token_id, + ) dl = DataLoader( dataset, @@ -352,7 +375,8 @@ def build_text_dataloader( def get_tokens_per_batch_func( - decoder_only: bool = True) -> Callable[[Batch], int]: + decoder_only: bool = True, +) -> Callable[[Batch], int]: """Returns a callable that counts the number of tokens in a batch. Args: @@ -365,15 +389,16 @@ def get_tokens_per_batch_func( """ def get_num_samples_in_batch(batch: Batch) -> int: - if not isinstance(batch, Mapping) or ('attention_mask' not in batch and - 'input_ids' not in batch): + if not isinstance(batch, Mapping) or ( + 'attention_mask' not in batch and 'input_ids' not in batch + ): raise ValueError( - 'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key' + 'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key', ) if not decoder_only and 'decoder_attention_mask' not in batch: raise ValueError( - 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key' + 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key', ) # Count number of non padding tokens in batch @@ -386,7 +411,8 @@ def get_num_samples_in_batch(batch: Batch) -> int: decoder_input_ids_tokens = 0 if not decoder_only: decoder_input_ids_tokens = int( - torch.sum(batch['decoder_attention_mask']).item()) + torch.sum(batch['decoder_attention_mask']).item(), + ) return input_ids_tokens + decoder_input_ids_tokens @@ -401,33 +427,42 @@ def get_num_samples_in_batch(batch: Batch) -> int: from llmfoundry.utils.builders import build_tokenizer parser = argparse.ArgumentParser() - parser.add_argument('--tokenizer', - type=str, - default='EleutherAI/gpt-neox-20b', - help='the name of the tokenizer to use') - parser.add_argument('--local_path', - type=str, - required=True, - help='the path to the local copy of the dataset') + parser.add_argument( + '--tokenizer', + type=str, + default='EleutherAI/gpt-neox-20b', + help='the name of the tokenizer to use', + ) + parser.add_argument( + '--local_path', + type=str, + required=True, + help='the path to the local copy of the dataset', + ) parser.add_argument( '--remote_path', type=str, default=None, - help='the path to the remote copy to stream from (optional)') - parser.add_argument('--split', - type=str, - default='val', - help='which split of the dataset to use') - parser.add_argument('--max_seq_len', - type=int, - default=32, - help='max sequence length to test') + help='the path to the remote copy to stream from (optional)', + ) + parser.add_argument( + '--split', + type=str, + default='val', + help='which split of the dataset to use', + ) + parser.add_argument( + '--max_seq_len', + type=int, + default=32, + help='max sequence length to test', + ) args = parser.parse_args() if args.remote_path is not None: print( - f'Reading {args.split} split from {args.local_path} <- streamed from <- {args.remote_path}' + f'Reading {args.split} split from {args.local_path} <- streamed from <- {args.remote_path}', ) else: print(f'Reading {args.split} split from {args.local_path}') diff --git a/llmfoundry/eval/__init__.py b/llmfoundry/eval/__init__.py index f2425296b6..54f8217920 100644 --- a/llmfoundry/eval/__init__.py +++ b/llmfoundry/eval/__init__.py @@ -2,16 +2,23 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.eval.datasets.in_context_learning_evaluation import ( - InContextLearningCodeEvalDataset, InContextLearningDataset, + InContextLearningCodeEvalDataset, + InContextLearningDataset, InContextLearningGenerationTaskWithAnswersDataset, - InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset, - InContextLearningSchemaTaskDataset, get_icl_task_dataloader) + InContextLearningLMTaskDataset, + InContextLearningMultipleChoiceTaskDataset, + InContextLearningSchemaTaskDataset, + get_icl_task_dataloader, +) from llmfoundry.eval.metrics.nlp import ( InContextLearningCodeEvalAccuracy, - InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningGenerationExactMatchAccuracy, + InContextLearningLMAccuracy, InContextLearningLMExpectedCalibrationError, - InContextLearningMCExpectedCalibrationError, InContextLearningMetric, - InContextLearningMultipleChoiceAccuracy) + InContextLearningMCExpectedCalibrationError, + InContextLearningMetric, + InContextLearningMultipleChoiceAccuracy, +) __all__ = [ 'InContextLearningDataset', diff --git a/llmfoundry/eval/datasets/__init__.py b/llmfoundry/eval/datasets/__init__.py index 0be9882b0c..517dc3e1f3 100644 --- a/llmfoundry/eval/datasets/__init__.py +++ b/llmfoundry/eval/datasets/__init__.py @@ -4,17 +4,25 @@ """Natively supported in-context learning evaluation datasets.""" from llmfoundry.eval.datasets.in_context_learning_evaluation import ( - InContextLearningCodeEvalDataset, InContextLearningDataset, + InContextLearningCodeEvalDataset, + InContextLearningDataset, InContextLearningGenerationTaskWithAnswersDataset, - InContextLearningLMTaskDataset, InContextLearningMultipleChoiceTaskDataset, - InContextLearningSchemaTaskDataset, get_icl_task_dataloader) - -# isort: off + InContextLearningLMTaskDataset, + InContextLearningMultipleChoiceTaskDataset, + InContextLearningSchemaTaskDataset, + get_icl_task_dataloader, +) from llmfoundry.eval.datasets.utils import ( - MultiTokenEOSCriteria, convert_tokens_to_tensors, get_continuation_span, - get_fewshot_sample_idxs, make_padded_input, stop_sequences_criteria, - strip_data, tokenizer_needs_prefix_space, trim_context) -# isort: on + MultiTokenEOSCriteria, + convert_tokens_to_tensors, + get_continuation_span, + get_fewshot_sample_idxs, + make_padded_input, + stop_sequences_criteria, + strip_data, + tokenizer_needs_prefix_space, + trim_context, +) __all__ = [ 'InContextLearningDataset', diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index 447855f953..591308272c 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -21,12 +21,15 @@ from datasets import IterableDataset, load_dataset from torch.utils.data import DataLoader, Dataset -from llmfoundry.eval.datasets.utils import (convert_tokens_to_tensors, - get_continuation_span, - get_fewshot_sample_idxs, - make_padded_input, strip_data, - tokenizer_needs_prefix_space, - trim_context) +from llmfoundry.eval.datasets.utils import ( + convert_tokens_to_tensors, + get_continuation_span, + get_fewshot_sample_idxs, + make_padded_input, + strip_data, + tokenizer_needs_prefix_space, + trim_context, +) from llmfoundry.utils.warnings import VersionedDeprecationWarning log = logging.getLogger(__name__) @@ -159,10 +162,12 @@ def __init__( self.tensor_keys = tensor_keys hf_loading_vars = hf_loading_vars or {} - self.dataset: HFDataset = self.read_dataset(dataset_uri, - destination_path, - hf_loading_vars, - hf_parsing_map) + self.dataset: HFDataset = self.read_dataset( + dataset_uri, + destination_path, + hf_loading_vars, + hf_parsing_map, + ) self.strip_data = strip_dataset if self.strip_data: self.dataset = self.dataset.map(strip_data) @@ -204,11 +209,12 @@ def update_generation_kwargs(self, generation_kwargs: Dict) -> None: self.base_batch['generation_kwargs'].update(generation_kwargs) def read_dataset( - self, - dataset_uri: str, - destination_path: str, - hf_loading_vars: Optional[Dict[str, Any]] = None, - hf_parsing_map: Optional[Dict[str, Any]] = None) -> 'HFDataset': + self, + dataset_uri: str, + destination_path: str, + hf_loading_vars: Optional[Dict[str, Any]] = None, + hf_parsing_map: Optional[Dict[str, Any]] = None, + ) -> 'HFDataset': """Reads a dataset and handles parsing it from HuggingFace. Args: @@ -238,16 +244,20 @@ def read_dataset( ) } assert isinstance(dataset, HFDataset) - dataset = dataset.map(dataset_parsing_func, - remove_columns=dataset.column_names) + dataset = dataset.map( + dataset_parsing_func, + remove_columns=dataset.column_names, + ) else: with dist.local_rank_zero_download_and_wait(destination_path): if dist.get_local_rank() == 0: get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', - data_files=destination_path, - split='train', - streaming=False) + dataset = load_dataset( + 'json', + data_files=destination_path, + split='train', + streaming=False, + ) assert isinstance(dataset, HFDataset) return dataset @@ -293,10 +303,12 @@ def _generate_few_shot_prompt( return few_shot_text - def construct_context(self, - example: Dict, - preceding_text: str = '', - add_answer: bool = False) -> str: + def construct_context( + self, + example: Dict, + preceding_text: str = '', + add_answer: bool = False, + ) -> str: """Takes an example and constructs a context, i.e. the input the model. reads for this example. Optionally adds the correct answer (for fewshot @@ -320,9 +332,11 @@ def construct_context(self, ctxt = f'{ctxt}{self.get_answer_from_example(example, in_context=add_answer)}' return ctxt - def get_answer_from_example(self, - example: Dict[str, Any], - in_context: bool = False) -> str: + def get_answer_from_example( + self, + example: Dict[str, Any], + in_context: bool = False, + ) -> str: """Returns the answer from the example. Args: @@ -350,13 +364,19 @@ def _fix_eos_on_preamble(self, input_ids: List[int]) -> List[int]: Returns: input_ids: The tokenized input conditionally edited """ - if (self.tokenizer.eos_token_id is not None and len(input_ids) > 1 and - input_ids[-1] == self.tokenizer.eos_token_id): + if ( + self.tokenizer.eos_token_id is not None and len(input_ids) > 1 and + input_ids[-1] == self.tokenizer.eos_token_id + ): input_ids = input_ids[:-1] return input_ids - def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, - example: Dict) -> Dict[str, Any]: + def tokenize_example( + self, + prompt_and_fewshot: str, + ctxt: str, + example: Dict, + ) -> Dict[str, Any]: """Runs text through the tokenizer and handle special cases. Args: @@ -377,7 +397,9 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, ctxt = ctxt.rstrip() # Never add special tokens to context tokenized_context = self.tokenizer( - ctxt, add_special_tokens=False)['input_ids'] + ctxt, + add_special_tokens=False, + )['input_ids'] assert isinstance(preamble, list) assert isinstance(tokenized_context, list) @@ -387,18 +409,26 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, # Never add special tokens to answer tokenized_answer = self.tokenizer( self.get_answer_from_example(example), - add_special_tokens=False)['input_ids'] + add_special_tokens=False, + )['input_ids'] assert isinstance(tokenized_answer, list) - trimmed_context = trim_context(tokenized_context, tokenized_answer, - self.padding_size) + trimmed_context = trim_context( + tokenized_context, + tokenized_answer, + self.padding_size, + ) assert isinstance(trimmed_context, list) continuation_indices = get_continuation_span( - trimmed_context, tokenized_answer) - padded_context = make_padded_input(trimmed_context, - tokenized_answer, - self.padding_size, - self.pad_tok_id, - self.padding_side) + trimmed_context, + tokenized_answer, + ) + padded_context = make_padded_input( + trimmed_context, + tokenized_answer, + self.padding_size, + self.pad_tok_id, + self.padding_side, + ) tokenized_example[self.context_key] = padded_context tokenized_example[self.answer_key] = tokenized_answer @@ -411,14 +441,17 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, self.padding_size, ) assert isinstance(trimmed_context, list) - padded_context = make_padded_input(trimmed_context, [], - self.padding_size, - self.pad_tok_id, - self.padding_side) + padded_context = make_padded_input( + trimmed_context, + [], + self.padding_size, + self.pad_tok_id, + self.padding_side, + ) tokenized_example[self.context_key] = padded_context - tokenized_example[self.answer_key] = self.get_answer_from_example( - example) + tokenized_example[self.answer_key + ] = self.get_answer_from_example(example) return tokenized_example @@ -448,12 +481,21 @@ def _prep_example( Dict: Contains a dictionary with the tokenized data """ prompt_and_fewshot = self._generate_few_shot_prompt( - num_fewshot, example_idx, prompt_string, fewshot_rng) - ctxt = self.construct_context(example, - prompt_and_fewshot, - add_answer=False) - tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, - example) + num_fewshot, + example_idx, + prompt_string, + fewshot_rng, + ) + ctxt = self.construct_context( + example, + prompt_and_fewshot, + add_answer=False, + ) + tokenized_example = self.tokenize_example( + prompt_and_fewshot, + ctxt, + example, + ) return tokenized_example def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: @@ -473,7 +515,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch[batch_key].append(data_pair[data_key]) if 'continuation_indices' in data_pair: batch['continuation_indices'].append( - data_pair['continuation_indices']) + data_pair['continuation_indices'], + ) batch = convert_tokens_to_tensors(batch, self.tokenize_labels) batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) @@ -497,7 +540,8 @@ def split_batch(self, batch: Any, # List split lists of strings if isinstance(microbatch_size, float): raise ValueError( - 'split_batch does not support floating point microbatch_size.') + 'split_batch does not support floating point microbatch_size.', + ) chunked = {} for k, v in batch.items(): if k in self.static_keys: @@ -514,14 +558,15 @@ def split_batch(self, batch: Any, if k in self.static_keys: chunked[k] = [v] * num_chunks - batched_list = [ - {k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks) - ] + batched_list = [{k: v[idx] + for k, v in chunked.items()} + for idx in range(num_chunks)] return batched_list -class InContextLearningGenerationTaskWithAnswersDataset(InContextLearningDataset - ): +class InContextLearningGenerationTaskWithAnswersDataset( + InContextLearningDataset, +): """A dataset that constructs batches for in-context learning generation. tasks with answers. Generation tasks evaluate a model's ability to @@ -540,32 +585,39 @@ class InContextLearningGenerationTaskWithAnswersDataset(InContextLearningDataset do_normalization (bool): Flag indicating whether to normalize generations before providing output. """ - def __init__(self, - cot_delimiter: str = '', - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True, - *args: Any, - **kwargs: Any): + def __init__( + self, + cot_delimiter: str = '', + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True, + *args: Any, + **kwargs: Any, + ): if kwargs['tokenizer'].eos_token_id is None: raise ValueError( - '`InContextLearningGenerationTaskWithAnswersDataset` tokenizer must have non-null `eos_token_id`' + '`InContextLearningGenerationTaskWithAnswersDataset` tokenizer must have non-null `eos_token_id`', ) self.cot_delimiter = cot_delimiter self.has_cot = False self.max_answer_length = 0 static_keys = [ - 'mode', 'cot_delimiter', 'generation_kwargs', 'do_normalization', - 'stopping_criteria' + 'mode', + 'cot_delimiter', + 'generation_kwargs', + 'do_normalization', + 'stopping_criteria', ] tensor_keys = ['input_ids', 'attention_mask'] list_keys = ['labels'] - super().__init__(padding_side='left', - tokenize_labels=False, - static_keys=static_keys, - list_keys=list_keys, - tensor_keys=tensor_keys, - *args, - **kwargs) + super().__init__( + padding_side='left', + tokenize_labels=False, + static_keys=static_keys, + list_keys=list_keys, + tensor_keys=tensor_keys, + *args, + **kwargs, + ) # NOTE: set these after init call because they take class vars self.early_stopping_criteria = early_stopping_criteria self.base_batch = { @@ -579,7 +631,7 @@ def __init__(self, 'pad_token_id': self.pad_tok_id, 'use_cache': True, 'eos_token_id': self.tokenizer.eos_token_id, - 'max_new_tokens': max(self.max_answer_length, 1) + 'max_new_tokens': max(self.max_answer_length, 1), }, } self.batch_mapping = { @@ -596,8 +648,12 @@ def read_dataset( hf_loading_vars: Dict, hf_parsing_map: Dict, ) -> 'HFDataset': - dataset = super().read_dataset(dataset_uri, destination_path, - hf_loading_vars, hf_parsing_map) + dataset = super().read_dataset( + dataset_uri, + destination_path, + hf_loading_vars, + hf_parsing_map, + ) self.has_cot = 'chain_of_thought' in dataset.features dataset = dataset.map( lambda examples: { @@ -609,7 +665,8 @@ def read_dataset( set([examples['answer']] + examples.get('aliases', [])), 'chain_of_thought': examples.get('chain_of_thought', ''), - }) + }, + ) self.max_answer_length = self._get_max_answer_length(dataset) # NOTE: This is the only time we use the class variable padding_size. if self.max_seq_len < self.max_answer_length: @@ -619,9 +676,11 @@ def read_dataset( self.padding_size = self.max_seq_len - self.max_answer_length return dataset - def get_answer_from_example(self, - example: Dict, - in_context: bool = False) -> str: + def get_answer_from_example( + self, + example: Dict, + in_context: bool = False, + ) -> str: """Returns the answer from the example. Applies chain of thought if. self.has_cot is marked as true. @@ -637,8 +696,12 @@ def get_answer_from_example(self, else: return example[self.answer_key] - def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, - example: Dict) -> Dict[str, Any]: + def tokenize_example( + self, + prompt_and_fewshot: str, + ctxt: str, + example: Dict, + ) -> Dict[str, Any]: """Run text through the tokenizer and handle special cases. Args: @@ -649,8 +712,11 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, Returns: Dict: Dictionary with the tokenized data """ - tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, - example) + tokenized_example = super().tokenize_example( + prompt_and_fewshot, + ctxt, + example, + ) tokenized_example['aliases'] = list(example.get('aliases', [])) return tokenized_example @@ -662,8 +728,9 @@ def _get_max_answer_length(self, dataset: Iterable[dict]) -> int: """ max_answer_length = 0 for example in dataset: - all_answers = [example[self.answer_key]] + list( - example.get('aliases', [])) + all_answers = [ + example[self.answer_key], + ] + list(example.get('aliases', [])) for answer in all_answers: if self.has_cot: response = ( @@ -673,10 +740,13 @@ def _get_max_answer_length(self, dataset: Iterable[dict]) -> int: response = answer tokenized_response = self.tokenizer(response)['input_ids'] assert isinstance(tokenized_response, list) - max_answer_length = max(max_answer_length, - len(tokenized_response)) + max_answer_length = max( + max_answer_length, + len(tokenized_response), + ) max_answer_length = max_answer_length + ( - _MAX_ANSWER_BUFFER_LENGTH if len(self.cot_delimiter) > 0 else 0) + _MAX_ANSWER_BUFFER_LENGTH if len(self.cot_delimiter) > 0 else 0 + ) return max_answer_length def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: @@ -688,9 +758,13 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: raise MissingConditionalImportError( extra_deps_group='nlp', conda_package='transformers', - conda_channel='conda-forge') + conda_channel='conda-forge', + ) stopping_criteria = stop_sequences_criteria( - self.tokenizer, self.early_stopping_criteria, batch_size) + self.tokenizer, + self.early_stopping_criteria, + batch_size, + ) batch['generation_kwargs']['stopping_criteria'] = stopping_criteria return batch @@ -709,25 +783,29 @@ class InContextLearningLMTaskDataset(InContextLearningDataset): """ def __init__(self, *args: Any, **kwargs: Any): - super().__init__(answer_key='continuation', - static_keys=['mode'], - tensor_keys=[ - 'input_ids', 'continuation_indices', 'labels', - 'attention_mask' - ], - base_batch={ - 'input_ids': [], - 'continuation_indices': [], - 'mode': 'icl_task', - 'labels': [] - }, - batch_mapping={ - 'input_ids': 'context', - 'labels': 'context' - }, - padding_side='right', - *args, - **kwargs) + super().__init__( + answer_key='continuation', + static_keys=['mode'], + tensor_keys=[ + 'input_ids', + 'continuation_indices', + 'labels', + 'attention_mask', + ], + base_batch={ + 'input_ids': [], + 'continuation_indices': [], + 'mode': 'icl_task', + 'labels': [], + }, + batch_mapping={ + 'input_ids': 'context', + 'labels': 'context', + }, + padding_side='right', + *args, + **kwargs, + ) class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): @@ -756,14 +834,16 @@ class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): choices_key (str): The key under which the choices are stored in the saved dataset. Defaults to 'choices'. """ - def __init__(self, - choices_key: str = 'choices', - static_keys: Optional[List] = None, - list_of_tensors_keys: Optional[List] = None, - list_of_tuples_keys: Optional[List] = None, - list_of_primitives: Optional[List] = None, - *args: Any, - **kwargs: Any): + def __init__( + self, + choices_key: str = 'choices', + static_keys: Optional[List] = None, + list_of_tensors_keys: Optional[List] = None, + list_of_tuples_keys: Optional[List] = None, + list_of_primitives: Optional[List] = None, + *args: Any, + **kwargs: Any, + ): self.choices_key = choices_key base_batch = { 'input_ids': [], @@ -775,30 +855,36 @@ def __init__(self, } context_key = kwargs.pop('context_key', 'query') static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs']) - tensor_keys = kwargs.pop('tensor_keys', - ['input_ids', 'labels', 'attention_mask']) + tensor_keys = kwargs.pop( + 'tensor_keys', + ['input_ids', 'labels', 'attention_mask'], + ) self.list_of_tensors_keys = list_of_tensors_keys or [ - 'continuation_indices' + 'continuation_indices', ] self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings'] self.list_of_primitives = list_of_primitives or ['gold_indices'] - super().__init__(context_key=context_key, - base_batch=base_batch, - static_keys=static_keys, - tensor_keys=tensor_keys, - padding_side='right', - *args, - **kwargs) + super().__init__( + context_key=context_key, + base_batch=base_batch, + static_keys=static_keys, + tensor_keys=tensor_keys, + padding_side='right', + *args, + **kwargs, + ) self.num_choices = len(self.dataset[0][self.choices_key]) self.batch_mapping_per_choice = { 'input_ids': 'context', - 'labels': 'context' + 'labels': 'context', } self.batch_map_per_example = {'gold_indices': 'gold'} - def get_answer_from_example(self, - example: Dict, - in_context: bool = False) -> str: + def get_answer_from_example( + self, + example: Dict, + in_context: bool = False, + ) -> str: """Returns the correct answer from the example's choices. Args: @@ -811,8 +897,12 @@ def get_answer_from_example(self, gold_idx = example['gold'] return choices[gold_idx] - def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, - example: Dict) -> Dict[str, Any]: + def tokenize_example( + self, + prompt_and_fewshot: str, + ctxt: str, + example: Dict, + ) -> Dict[str, Any]: """Runs text through the tokenizer and handle special cases. Args: @@ -834,7 +924,9 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, ctxt = ctxt.rstrip() # Never add special tokens to context tokenized_context = self.tokenizer( - ctxt, add_special_tokens=False)['input_ids'] + ctxt, + add_special_tokens=False, + )['input_ids'] assert isinstance(tokenized_context, list) tokenized_context = preamble + tokenized_context @@ -848,14 +940,21 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, # Never add special tokens to answer tokenized_answer = self.tokenizer( - choice, add_special_tokens=False)['input_ids'] + choice, + add_special_tokens=False, + )['input_ids'] assert isinstance(tokenized_context, list) assert isinstance(tokenized_answer, list) - trimmed_context = trim_context(tokenized_context, tokenized_answer, - self.padding_size) + trimmed_context = trim_context( + tokenized_context, + tokenized_answer, + self.padding_size, + ) assert isinstance(trimmed_context, list) continuation_indices = get_continuation_span( - trimmed_context, tokenized_answer) + trimmed_context, + tokenized_answer, + ) padded_context = make_padded_input( trimmed_context, tokenized_answer, @@ -867,7 +966,8 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, tokenized_example[self.context_key].append(padded_context) tokenized_example[self.answer_key].append(tokenized_answer) tokenized_example['continuation_indices'].append( - continuation_indices) + continuation_indices, + ) tokenized_example['gold'] = example['gold'] return tokenized_example @@ -896,7 +996,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: for i, context_enc in enumerate(data_pair[self.context_key]): batch['input_ids'].append(context_enc) batch['continuation_indices'].append( - data_pair['continuation_indices'][i]) + data_pair['continuation_indices'][i], + ) batch['labels'].append(context_enc) batch['gold_indices'].append(data_pair['gold']) @@ -930,7 +1031,8 @@ def split_batch(self, batch: Any, """ if isinstance(microbatch_size, float): raise ValueError( - 'split_batch does not support floating point microbatch_size.') + 'split_batch does not support floating point microbatch_size.', + ) chunked = {} for k, v in batch.items(): if k in self.static_keys: @@ -939,8 +1041,10 @@ def split_batch(self, batch: Any, elif type(v) == list: # list of tensors - 'continuation_indices' if k in self.list_of_tensors_keys: - chunked[k] = _split_list(v, - microbatch_size * self.num_choices) + chunked[k] = _split_list( + v, + microbatch_size * self.num_choices, + ) # list of tuples - 'choice_groupings' elif k in self.list_of_tuples_keys: chunked[k] = _split_list(v, microbatch_size) @@ -951,7 +1055,9 @@ def split_batch(self, batch: Any, raise ValueError(f'Unexpected key {k} in list splitting') elif k in self.tensor_keys: chunked[k] = _default_split_batch( - v, microbatch_size * self.num_choices) + v, + microbatch_size * self.num_choices, + ) else: raise ValueError(f'Unexpected key {k} in batch splitting') num_chunks = len(chunked['input_ids']) @@ -960,13 +1066,14 @@ def split_batch(self, batch: Any, if k in self.static_keys: chunked[k] = [v] * num_chunks - return [ - {k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks) - ] + return [{k: v[idx] + for k, v in chunked.items()} + for idx in range(num_chunks)] class InContextLearningSchemaTaskDataset( - InContextLearningMultipleChoiceTaskDataset): + InContextLearningMultipleChoiceTaskDataset, +): """A dataset that constructs batches for in-context learning schema. evaluation. A schema task involves sentences with a fill-in-the-blank where @@ -989,20 +1096,24 @@ class InContextLearningSchemaTaskDataset( - choice_groupings: Indicates which indices of the batch correspond to which questions """ - def __init__(self, - choices_key: str = 'context_options', - *args: Any, - **kwargs: Any): + def __init__( + self, + choices_key: str = 'context_options', + *args: Any, + **kwargs: Any, + ): static_keys = ['mode'] tensor_keys = ['input_ids', 'labels', 'attention_mask'] list_of_tensors_keys = ['continuation_indices'] - super().__init__(choices_key=choices_key, - context_key=choices_key, - static_keys=static_keys, - tensor_keys=tensor_keys, - list_of_tensors_keys=list_of_tensors_keys, - *args, - **kwargs) + super().__init__( + choices_key=choices_key, + context_key=choices_key, + static_keys=static_keys, + tensor_keys=tensor_keys, + list_of_tensors_keys=list_of_tensors_keys, + *args, + **kwargs, + ) self.base_batch = { 'input_ids': [], 'continuation_indices': [], @@ -1012,10 +1123,12 @@ def __init__(self, 'choice_groupings': [], } - def construct_context(self, - example: Dict[str, Any], - preceding_text: str = '', - add_answer: bool = False) -> str: + def construct_context( + self, + example: Dict[str, Any], + preceding_text: str = '', + add_answer: bool = False, + ) -> str: """Takes a example and constructs a context with the correct context. for. @@ -1039,9 +1152,11 @@ def construct_context(self, context = f'{self.prelimiter}{context}{self.continuation_delimiter}{continuation}' return context - def _construct_multiple_contexts(self, - example: Dict, - preceding_text: str = '') -> List[str]: + def _construct_multiple_contexts( + self, + example: Dict, + preceding_text: str = '', + ) -> List[str]: """Takes a example and constructs all contexts. Optionally, appends this to preceding text (such as a prompt or fewshot examples). @@ -1093,15 +1208,25 @@ def _prep_example( Dict: Contains a dictionary with the tokenized data """ prompt_and_fewshot = self._generate_few_shot_prompt( - num_fewshot, example_idx, prompt_string, fewshot_rng) + num_fewshot, + example_idx, + prompt_string, + fewshot_rng, + ) ctxt = self._construct_multiple_contexts(example, prompt_and_fewshot) - tokenized_example = self.tokenize_example(prompt_and_fewshot, ctxt, - example) + tokenized_example = self.tokenize_example( + prompt_and_fewshot, + ctxt, + example, + ) return tokenized_example - def tokenize_example(self, prompt_and_fewshot: str, - context_options: List[str], - example: Dict) -> Dict[str, Any]: + def tokenize_example( + self, + prompt_and_fewshot: str, + context_options: List[str], + example: Dict, + ) -> Dict[str, Any]: """Runs text through the tokenizer and handle special cases. Args: @@ -1125,10 +1250,14 @@ def tokenize_example(self, prompt_and_fewshot: str, ] continuation = example['continuation'] if self.prefix_space: - continuation = (f' {continuation}' if - not continuation.startswith(' ') else continuation) + continuation = ( + f' {continuation}' + if not continuation.startswith(' ') else continuation + ) tokenized_continuation = self.tokenizer( - continuation, add_special_tokens=False)['input_ids'] + continuation, + add_special_tokens=False, + )['input_ids'] tokenized_example[self.context_key] = [] tokenized_example['continuation_indices'] = [] @@ -1136,19 +1265,27 @@ def tokenize_example(self, prompt_and_fewshot: str, for context in encoded_contexts: assert isinstance(context, list) assert isinstance(tokenized_continuation, list) - trimmed_context = trim_context(context, tokenized_continuation, - self.padding_size) + trimmed_context = trim_context( + context, + tokenized_continuation, + self.padding_size, + ) assert isinstance(trimmed_context, list) continuation_indices = get_continuation_span( - trimmed_context, tokenized_continuation) - padded_context = make_padded_input(trimmed_context, - tokenized_continuation, - self.padding_size, - self.pad_tok_id, - self.padding_side) + trimmed_context, + tokenized_continuation, + ) + padded_context = make_padded_input( + trimmed_context, + tokenized_continuation, + self.padding_size, + self.pad_tok_id, + self.padding_side, + ) tokenized_example[self.context_key].append(padded_context) tokenized_example['continuation_indices'].append( - continuation_indices) + continuation_indices, + ) tokenized_example[self.answer_key].append(tokenized_continuation) tokenized_example['gold'] = example['gold'] @@ -1208,7 +1345,7 @@ def __init__( pass_at_k = [pass_at_k] if generations_per_sample < max(pass_at_k): raise ValueError( - f'generations_per_sample ({generations_per_sample}) must be greater than or equal to pass_at_k ({pass_at_k}) for code evaluation.' + f'generations_per_sample ({generations_per_sample}) must be greater than or equal to pass_at_k ({pass_at_k}) for code evaluation.', ) batch_mapping = { 'input_ids': 'prompt', @@ -1287,7 +1424,7 @@ def __init__( 'temperature': 0.2, # good default for code 'use_cache': True, 'eos_token_id': self.tokenizer.eos_token_id, - 'max_new_tokens': max(max_new_tokens, 1) + 'max_new_tokens': max(max_new_tokens, 1), }, 'sample_id': [], 'pass_at_k': list(pass_at_k), @@ -1332,7 +1469,8 @@ def _set_max_prompt_and_answer_lengths(self): tokenized_answer = self.tokenizer( example['canonical_solution'], - add_special_tokens=False)['input_ids'] + add_special_tokens=False, + )['input_ids'] assert isinstance(tokenized_answer, list) len_tokenized_answer = len(tokenized_answer) max_answer_length = max(max_answer_length, len_tokenized_answer) @@ -1356,21 +1494,32 @@ def _trim_padding(self, example: Dict): ] # Reapply padding only to max_prompt_length full_prompt = trim_context(unpadded_prompt, [], self.max_prompt_length) - padded_context = make_padded_input(full_prompt, [], - self.max_prompt_length, - self.pad_tok_id, self.padding_side) + padded_context = make_padded_input( + full_prompt, + [], + self.max_prompt_length, + self.pad_tok_id, + self.padding_side, + ) example[self.context_key] = padded_context return example - def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, - example: Dict) -> Dict[str, Any]: + def tokenize_example( + self, + prompt_and_fewshot: str, + ctxt: str, + example: Dict, + ) -> Dict[str, Any]: """Adds extra code task details to the example dictionary. See InContextLearningDataset for more details """ - tokenized_example = super().tokenize_example(prompt_and_fewshot, ctxt, - example) + tokenized_example = super().tokenize_example( + prompt_and_fewshot, + ctxt, + example, + ) tokenized_example['prompt_text'] = example['prompt'] tokenized_example['task_id'] = example['task_id'] tokenized_example['canonical_solution'] = example['canonical_solution'] @@ -1383,27 +1532,28 @@ def tokenize_example(self, prompt_and_fewshot: str, ctxt: str, def build_icl_dataloader( - icl_task_type: str, - dataset_uri: str, - tokenizer: transformers.PreTrainedTokenizerBase, - batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str, # e.g. '' - hf_loading_vars: Dict, - hf_parsing_map: Dict, - destination_path: str, - prelimiter: str, # e.g. 'Question: ' - cot_delimiter: str, # e.g. ' ### ' - fewshot_random_seed: int, - pass_at_k: int, - generations_per_sample: int, - generation_kwargs: Dict, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True) -> DataSpec: + icl_task_type: str, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + batch_size: int, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, # e.g. 'translate english to french:' + example_delimiter: str, # e.g. '\n' + continuation_delimiter: str, # e.g. '' + hf_loading_vars: Dict, + hf_parsing_map: Dict, + destination_path: str, + prelimiter: str, # e.g. 'Question: ' + cot_delimiter: str, # e.g. ' ### ' + fewshot_random_seed: int, + pass_at_k: int, + generations_per_sample: int, + generation_kwargs: Dict, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True, +) -> DataSpec: """Factory method that builds the specific dataset for the specified. icl_task_type. See documentation for `get_icl_task_dataloader` for argument @@ -1476,7 +1626,9 @@ def build_icl_dataloader( warnings.warn( VersionedDeprecationWarning( "ICL task type 'question_answering' is now deprecated. Use identifier 'generation_task_with_answers'", - 'v0.9.0')) + 'v0.9.0', + ), + ) dataset = InContextLearningGenerationTaskWithAnswersDataset( dataset_uri=dataset_uri, tokenizer=tokenizer, @@ -1501,7 +1653,9 @@ def build_icl_dataloader( warnings.warn( VersionedDeprecationWarning( "ICL task type 'code_evaluation' is deprecated and will no longer be supported. ", - 'v0.9.0')) + 'v0.9.0', + ), + ) dataset = InContextLearningCodeEvalDataset( dataset_uri=dataset_uri, tokenizer=tokenizer, @@ -1528,7 +1682,7 @@ def build_icl_dataloader( split_batch = None if isinstance( - dataset, + dataset, ( InContextLearningMultipleChoiceTaskDataset, InContextLearningGenerationTaskWithAnswersDataset, @@ -1550,9 +1704,12 @@ def build_icl_dataloader( ) -def partition_dataset_by_category(dataset_uri: str, destination_path: str, - hf_loading_vars: Dict, - hf_parsing_map: Dict) -> Dict[str, str]: +def partition_dataset_by_category( + dataset_uri: str, + destination_path: str, + hf_loading_vars: Dict, + hf_parsing_map: Dict, +) -> Dict[str, str]: """If has_categories is enabled, we partition the dataset into a separate. dataset for each category value in the data and write each partition to a @@ -1572,35 +1729,43 @@ def partition_dataset_by_category(dataset_uri: str, destination_path: str, if dataset_uri.startswith('hf://'): dataset_uri = dataset_uri.replace('hf://', '') dataset = load_dataset(dataset_uri, **hf_loading_vars) - assert isinstance(dataset, HFDataset) or isinstance( - dataset, IterableDataset) + assert isinstance(dataset, + HFDataset) or isinstance(dataset, IterableDataset) if hf_parsing_map: dataset_parsing_func = lambda example: { k: ' '.join([str(example[col]) for col in v]) for k, v in hf_parsing_map.items() } assert hasattr(dataset, 'column_names') - dataset = dataset.map(dataset_parsing_func, - remove_columns=dataset.column_names) + dataset = dataset.map( + dataset_parsing_func, + remove_columns=dataset.column_names, + ) else: with dist.local_rank_zero_download_and_wait(destination_path): if dist.get_local_rank() == 0: get_file(dataset_uri, destination_path, overwrite=True) - dataset = load_dataset('json', - data_files=destination_path, - split='train', - streaming=False) - assert isinstance(dataset, HFDataset) or isinstance(dataset, - IterableDataset) + dataset = load_dataset( + 'json', + data_files=destination_path, + split='train', + streaming=False, + ) + assert isinstance(dataset, + HFDataset) or isinstance(dataset, IterableDataset) assert hasattr(dataset, 'features') assert dataset.features is not None if 'category' not in dataset.features.keys(): - raise Exception(f"""Attempted to partition dataset by `category` \ + raise Exception( + f"""Attempted to partition dataset by `category` \ but it doesn't have a `category` key. \ - Got keys: {str(list(dataset.features.keys()))}""") + Got keys: {str(list(dataset.features.keys()))}""", + ) categories = sorted( - set(dataset['category'] - )) # pyright: ignore[reportIndexIssue, reportGeneralTypeIssues] + set( + dataset['category'], + ), + ) # pyright: ignore[reportIndexIssue, reportGeneralTypeIssues] output_files = {} for cat in categories: path = destination_path.split('/') @@ -1620,29 +1785,30 @@ def partition_dataset_by_category(dataset_uri: str, destination_path: str, def get_icl_task_dataloader( - icl_task_type: str, - dataset_uri: str, - tokenizer: Union[transformers.PreTrainedTokenizer, - transformers.PreTrainedTokenizerFast], - batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str = '', - destination_path: str = '', - question_prelimiter: str = '', # e.g. 'Question: ' - fewshot_random_seed: int = 1234, - pass_at_k: int = 1, - generations_per_sample: int = 1, - cot_delimiter: str = '', - has_categories: bool = False, - hf_loading_vars: Optional[Dict] = None, - hf_parsing_map: Optional[Dict] = None, - generation_kwargs: Optional[Dict] = None, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True) -> Union[DataSpec, Dict[str, DataSpec]]: + icl_task_type: str, + dataset_uri: str, + tokenizer: Union[transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast], + batch_size: int, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + prompt_string: str, # e.g. 'translate english to french:' + example_delimiter: str, # e.g. '\n' + continuation_delimiter: str = '', + destination_path: str = '', + question_prelimiter: str = '', # e.g. 'Question: ' + fewshot_random_seed: int = 1234, + pass_at_k: int = 1, + generations_per_sample: int = 1, + cot_delimiter: str = '', + has_categories: bool = False, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + early_stopping_criteria: Optional[List[str]] = None, + do_normalization: bool = True, +) -> Union[DataSpec, Dict[str, DataSpec]]: r"""Constructs a dataloader (or dataloaders if has_categories is True) capable of evaluating LLMs on in-context learning language modeling tasks, @@ -1735,10 +1901,12 @@ def get_icl_task_dataloader( if has_categories: result_dls = {} - output_files = partition_dataset_by_category(dataset_uri, - destination_path, - hf_loading_vars, - hf_parsing_map) + output_files = partition_dataset_by_category( + dataset_uri, + destination_path, + hf_loading_vars, + hf_parsing_map, + ) categories = sorted(output_files.keys()) for category in categories: partition_uri = output_files[category] diff --git a/llmfoundry/eval/datasets/utils.py b/llmfoundry/eval/datasets/utils.py index 6433e7cb56..1ce249437d 100644 --- a/llmfoundry/eval/datasets/utils.py +++ b/llmfoundry/eval/datasets/utils.py @@ -43,7 +43,8 @@ def strip_data(example: Dict) -> Dict: def tokenizer_needs_prefix_space( - tokenizer: transformers.PreTrainedTokenizerBase) -> bool: + tokenizer: transformers.PreTrainedTokenizerBase, +) -> bool: """Test for whether a prefix space is needed before the continuation. Sentencepiece tokenization should not have a prefix space, but gpt2 style @@ -60,8 +61,11 @@ def tokenizer_needs_prefix_space( return len(test_tokens) == 1 -def trim_context(context_enc: List, continuation_enc: List, - max_seq_len: int) -> List: +def trim_context( + context_enc: List, + continuation_enc: List, + max_seq_len: int, +) -> List: """Trims a list of tokens down to `max_seq_len` if the length of the list. plus the continuation is more than `max_seq_len`. It will always trim tokens @@ -81,15 +85,18 @@ def trim_context(context_enc: List, continuation_enc: List, if context_max_subseq_len < 0: # can't support continuations which are longer than the max seq len raise Exception( - f'Dataset included continuation longer than the max seq len') + f'Dataset included continuation longer than the max seq len', + ) # clip from the end context_enc = context_enc[-(context_max_subseq_len):] return context_enc -def get_continuation_span(context_enc: List, - continuation_enc: List) -> torch.Tensor: +def get_continuation_span( + context_enc: List, + continuation_enc: List, +) -> torch.Tensor: """Gets the list of indices of the continuation tokens for language. modeling. @@ -105,14 +112,17 @@ def get_continuation_span(context_enc: List, """ return torch.tensor( range(len(context_enc), - len(context_enc) + len(continuation_enc))) + len(context_enc) + len(continuation_enc)), + ) -def make_padded_input(context_enc: List, - continuation_enc: List, - max_seq_len: int, - pad_tok_id: int, - padding_side: str = 'right') -> torch.Tensor: +def make_padded_input( + context_enc: List, + continuation_enc: List, + max_seq_len: int, + pad_tok_id: int, + padding_side: str = 'right', +) -> torch.Tensor: """Takes an encoded context and continuation and clips the beginning of the. context if they're too long. Adds the padding token to the specified side. @@ -138,7 +148,7 @@ def make_padded_input(context_enc: List, # token and cause errors if not isinstance(pad_tok_id, int): raise ValueError( - f'`pad_tok_id` must be an integer. Found {type(pad_tok_id)} instead' + f'`pad_tok_id` must be an integer. Found {type(pad_tok_id)} instead', ) # pad length from seq to padding_length if padding_side == 'right': @@ -159,7 +169,7 @@ def make_padded_input(context_enc: List, ) else: raise ValueError( - f"Unknown padding_side {padding_side}. padding_side must be either 'left' or 'right'" + f"Unknown padding_side {padding_side}. padding_side must be either 'left' or 'right'", ) return inp @@ -181,17 +191,23 @@ def convert_tokens_to_tensors(batch: Dict, Returns: dict: The batch with torch tensors in the corresponding keys instead of lists of lists """ - batch['input_ids'] = torch.stack(list(map(torch.tensor, - batch['input_ids']))) + batch['input_ids'] = torch.stack( + list(map(torch.tensor, batch['input_ids'])), + ) if tokenize_labels: batch['labels'] = torch.stack(list(map(torch.tensor, batch['labels']))) batch['continuation_indices'] = list( - map(torch.tensor, batch['continuation_indices'])) + map(torch.tensor, batch['continuation_indices']), + ) return batch -def get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int, - example_idx: int, rng: random.Random) -> Set[int]: +def get_fewshot_sample_idxs( + dataset_size: int, + num_fewshot: int, + example_idx: int, + rng: random.Random, +) -> Set[int]: """Samples indices without replacement. If num_fewshot exceeds the number. of unique examples in the dataset, then we will have fewer than num_fewshot examples in context. @@ -234,8 +250,10 @@ def __init__( ) -> None: self.done_tracker = [False] * batch_size self.stop_sequence = stop_sequence - self.stop_sequence_ids = tokenizer.encode(stop_sequence, - add_special_tokens=False) + self.stop_sequence_ids = tokenizer.encode( + stop_sequence, + add_special_tokens=False, + ) # sentence piece tokenizers add a superfluous underline token before string-initial \n # that throws off our calculation of the stop sequence length @@ -252,10 +270,12 @@ def __init__( self.stop_sequence_id_len = len(self.stop_sequence_ids) + 1 self.tokenizer = tokenizer - def __call__(self, - input_ids: torch.LongTensor, - scores: Optional[torch.FloatTensor] = None, - **kwargs: Dict[str, Any]) -> bool: + def __call__( + self, + input_ids: torch.LongTensor, + scores: Optional[torch.FloatTensor] = None, + **kwargs: Dict[str, Any], + ) -> bool: # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence lookback_ids_batch = input_ids[:, :][:, -self.stop_sequence_id_len:] lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) diff --git a/llmfoundry/eval/metrics/__init__.py b/llmfoundry/eval/metrics/__init__.py index 079439da59..6a50fcb484 100644 --- a/llmfoundry/eval/metrics/__init__.py +++ b/llmfoundry/eval/metrics/__init__.py @@ -5,10 +5,13 @@ from llmfoundry.eval.metrics.nlp import ( InContextLearningCodeEvalAccuracy, - InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningGenerationExactMatchAccuracy, + InContextLearningLMAccuracy, InContextLearningLMExpectedCalibrationError, - InContextLearningMCExpectedCalibrationError, InContextLearningMetric, - InContextLearningMultipleChoiceAccuracy) + InContextLearningMCExpectedCalibrationError, + InContextLearningMetric, + InContextLearningMultipleChoiceAccuracy, +) __all__ = [ 'InContextLearningMetric', diff --git a/llmfoundry/eval/metrics/nlp.py b/llmfoundry/eval/metrics/nlp.py index f5a50721e3..a7764a0d0a 100644 --- a/llmfoundry/eval/metrics/nlp.py +++ b/llmfoundry/eval/metrics/nlp.py @@ -15,9 +15,12 @@ import numpy as np import torch from composer.utils import dist -from composer.utils.eval_client import (EvalClient, LambdaEvalClient, - LocalEvalClient, - MosaicMLLambdaEvalClient) +from composer.utils.eval_client import ( + EvalClient, + LambdaEvalClient, + LocalEvalClient, + MosaicMLLambdaEvalClient, +) from torch import Tensor from torch.nn import functional as F from torchmetrics import Metric @@ -125,9 +128,11 @@ class InContextLearningGenerationExactMatchAccuracy(InContextLearningMetric): def __init__(self, dist_sync_on_step: bool = False): # state from multiple processes super().__init__(dist_sync_on_step=dist_sync_on_step) - self.add_state('correct', - default=torch.tensor(0.), - dist_reduce_fx='sum') + self.add_state( + 'correct', + default=torch.tensor(0.), + dist_reduce_fx='sum', + ) self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') self.metric_result_dict = { 'cleaned_output': [], @@ -149,8 +154,9 @@ def white_space_fix(text: str) -> str: return ' '.join(text.split()) def handle_punc(text: str) -> str: - exclude = set(string.punctuation + - ''.join([u'‘', u'’', u'´', u'`'])) + exclude = set( + string.punctuation + ''.join([u'‘', u'’', u'´', u'`']), + ) return ''.join(ch if ch not in exclude else ' ' for ch in text) def lower(text: str) -> str: @@ -160,8 +166,8 @@ def replace_underscore(text: str) -> str: return text.replace('_', ' ') return white_space_fix( - remove_articles(handle_punc(lower( - replace_underscore(answer))))).strip() + remove_articles(handle_punc(lower(replace_underscore(answer)))), + ).strip() def update( self, @@ -177,8 +183,10 @@ def update( final_answer = sample_output if stopping_criteria is not None and len(stopping_criteria) > 0: - final_answer = re.split('|'.join(stopping_criteria), - final_answer)[0] + final_answer = re.split( + '|'.join(stopping_criteria), + final_answer, + )[0] if cot_delimiter is not None and len(cot_delimiter) > 0: final_answer = final_answer.split(cot_delimiter)[-1] @@ -199,8 +207,9 @@ def update( metric_result_dict['cleaned_label'].append(cleaned_sample_labels) if any( - cleaned_final_answer.startswith(label) - for label in cleaned_sample_labels): + cleaned_final_answer.startswith(label) + for label in cleaned_sample_labels + ): self.correct += torch.tensor(1.0) metric_result_dict['result'].append(1) else: @@ -243,29 +252,35 @@ class InContextLearningLMAccuracy(InContextLearningMetric): def __init__(self, dist_sync_on_step: bool = False): # state from multiple processes super().__init__(dist_sync_on_step=dist_sync_on_step) - self.add_state('correct', - default=torch.tensor(0.), - dist_reduce_fx='sum') + self.add_state( + 'correct', + default=torch.tensor(0.), + dist_reduce_fx='sum', + ) self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum') self.metric_result_dict = { 'context': [], 'label': [], 'output': [], - 'result': [] + 'result': [], } def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): metric_result_dict = copy.deepcopy(self.metric_result_dict) for batch_idx, cont_idx in enumerate(batch['continuation_indices']): - cont_tok_pred = outputs[batch_idx].index_select(dim=0, - index=cont_idx - - 1).argmax(dim=-1) - cont_tok_targ = labels[batch_idx].index_select(dim=0, - index=cont_idx - 1) + cont_tok_pred = outputs[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ).argmax(dim=-1) + cont_tok_targ = labels[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ) metric_result_dict['context'].append( - batch['input_ids'][batch_idx][:cont_idx[0]]) + batch['input_ids'][batch_idx][:cont_idx[0]], + ) metric_result_dict['label'].append(cont_tok_targ) metric_result_dict['output'].append(cont_tok_pred) @@ -308,9 +323,11 @@ class InContextLearningMultipleChoiceAccuracy(InContextLearningMetric): def __init__(self, dist_sync_on_step: bool = False): # state from multiple processes super().__init__(dist_sync_on_step=dist_sync_on_step) - self.add_state('correct', - default=torch.tensor(0.0), - dist_reduce_fx='sum') + self.add_state( + 'correct', + default=torch.tensor(0.0), + dist_reduce_fx='sum', + ) self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum') self.metric_result_dict = { 'context': [], @@ -327,19 +344,24 @@ def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): perplexities = [] for batch_idx, cont_idx in enumerate(batch['continuation_indices']): # continuation indices refer to indices in the original input's token space - cont_tok_logits = outputs[batch_idx].index_select(dim=0, - index=cont_idx - - 1) + cont_tok_logits = outputs[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ) # labels have been shifted left by one index, so the cont_idx needs to be shifted as well. - cont_tok_targ = labels[batch_idx].index_select(dim=0, - index=cont_idx - 1) + cont_tok_targ = labels[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ) cross_entropy = F.cross_entropy(cont_tok_logits, cont_tok_targ) perplexity = torch.exp(cross_entropy) perplexities.append(perplexity) metric_result_dict = copy.deepcopy(self.metric_result_dict) - for (start, end), gold_idx in zip(batch['choice_groupings'], - batch['gold_indices']): + for (start, end), gold_idx in zip( + batch['choice_groupings'], + batch['gold_indices'], + ): subset = perplexities[start:end] idx_min = subset.index(min(subset)) @@ -424,7 +446,7 @@ def __init__(self, dist_sync_on_step: bool = False): 'context': [], 'output': [], 'result': [], - 'sample_id': [] + 'sample_id': [], } def get_client(self) -> EvalClient: @@ -435,7 +457,8 @@ def get_client(self) -> EvalClient: 'Running code eval locally may be insecure. Please set environment variable CODE_EVAL_DEVICE ' + 'to LAMBDA to run on remote. To use Lambdas, spin up your instance that checks code, set the URL as ' - + 'CODE_EVAL_URL and the API key as CODE_EVAL_APIKEY.') + + 'CODE_EVAL_URL and the API key as CODE_EVAL_APIKEY.', + ) log.debug('Running code eval locally.') client = LocalEvalClient() elif self.eval_device == 'LAMBDA': @@ -449,11 +472,13 @@ def get_client(self) -> EvalClient: 'variable `CODE_EVAL_DEVICE` is not set. Please set it to `CODE_EVAL_DEVICE` ' + 'to one of `LOCAL` (for unsafe local eval), `LAMBDA` (for AWS lambda ' - + 'evaluation), or `MOSAICML` (for lambda eval through MAPI).') + + 'evaluation), or `MOSAICML` (for lambda eval through MAPI).', + ) else: raise ValueError( 'Environment variable `CODE_EVAL_DEVICE` must be one of `LOCAL`, ' - + f'`LAMBDA`, or `MOSAICML` but got {self.eval_device}.') + + f'`LAMBDA`, or `MOSAICML` but got {self.eval_device}.', + ) return client @@ -476,17 +501,25 @@ def _initialize_state(self, batch: dict[str, Any]): self.num_generations = batch['generations_per_sample'] # We need to defer the accumulator initialization because it depends on dataset size - self.add_state('correct', - default=torch.zeros(self.dataset_size, device=device), - dist_reduce_fx='sum') - self.add_state('total', - default=torch.zeros(self.dataset_size, device=device), - dist_reduce_fx='sum') + self.add_state( + 'correct', + default=torch.zeros(self.dataset_size, device=device), + dist_reduce_fx='sum', + ) + self.add_state( + 'total', + default=torch.zeros(self.dataset_size, device=device), + dist_reduce_fx='sum', + ) dist.barrier() self._initialized = True - def update(self, batch: Dict[str, Any], outputs: List[str], - labels: List[str]): + def update( + self, + batch: Dict[str, Any], + outputs: List[str], + labels: List[str], + ): """Updates the pass@k accuracy of code generation. Given a batch of prompts, test cases, and code generations, evaluates the code generations @@ -518,17 +551,21 @@ def update(self, batch: Dict[str, Any], outputs: List[str], metric_result_dict = copy.deepcopy(self.metric_result_dict) for sample_id, code_gen, sample_prompt, test_inputs, test_outputs, entry_point, language in zip( - batch['sample_id'], outputs, batch['prompts'], - batch['test_inputs'], batch['test_outputs'], - batch['entry_points'], batch['languages']): + batch['sample_id'], + outputs, + batch['prompts'], + batch['test_inputs'], + batch['test_outputs'], + batch['entry_points'], + batch['languages'], + ): idx = sample_id self.total[idx] += 1.0 metric_result_dict['sample_id'].append(sample_id) - code_gen = re.split( - r'\n[A-Za-z0-9#`]', - code_gen)[0] # remove everything after function ends + code_gen = re.split(r'\n[A-Za-z0-9#`]', code_gen)[ + 0] # remove everything after function ends final_code = sample_prompt + code_gen # combine prompt with the code generation metric_result_dict['context'].append(sample_prompt) metric_result_dict['output'].append(code_gen) @@ -564,12 +601,12 @@ def compute(self): warnings.warn( 'Some samples in the dataset have less than the expected number of generations. ' + - 'This is expected if you are using a subset of the dataset for evaluation.' + 'This is expected if you are using a subset of the dataset for evaluation.', ) if (self.correct > self.total).any().item(): raise ValueError( - 'Internal error some samples have more correct than total generations. This should not happen.' + 'Internal error some samples have more correct than total generations. This should not happen.', ) results = {} @@ -615,12 +652,16 @@ def __init__(self, dist_sync_on_step: bool = False, n_buckets: int = 10): self.n_buckets = n_buckets if n_buckets < 1: raise Exception('`n_buckets`') - self.add_state('bucket_totals', - default=torch.zeros(n_buckets), - dist_reduce_fx='sum') - self.add_state('bucket_correct', - default=torch.zeros(n_buckets), - dist_reduce_fx='sum') + self.add_state( + 'bucket_totals', + default=torch.zeros(n_buckets), + dist_reduce_fx='sum', + ) + self.add_state( + 'bucket_correct', + default=torch.zeros(n_buckets), + dist_reduce_fx='sum', + ) def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): pass @@ -646,7 +687,8 @@ def compute(self): class InContextLearningMCExpectedCalibrationError( - InContextLearningExpectedCalibrationError): + InContextLearningExpectedCalibrationError, +): r"""Computes Expected Calibration Error (ECE) for In-context learning (ICL) multiple choice (MC) tasks. (source: https://arxiv.org/abs/2012.00955). @@ -664,17 +706,24 @@ def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): outputs = torch.softmax(outputs, dim=2) probabilities = [] for batch_idx, cont_idx in enumerate(batch['continuation_indices']): - cont_tok_logits = outputs[batch_idx].index_select(dim=0, - index=cont_idx - - 1) - cont_tok_targ = labels[batch_idx].index_select(dim=0, - index=cont_idx - 1) + cont_tok_logits = outputs[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ) + cont_tok_targ = labels[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ) probability = cont_tok_logits.index_select( - dim=1, index=cont_tok_targ).diagonal().mean() + dim=1, + index=cont_tok_targ, + ).diagonal().mean() probabilities.append(probability) - for (start, end), gold_idx in zip(batch['choice_groupings'], - batch['gold_indices']): + for (start, end), gold_idx in zip( + batch['choice_groupings'], + batch['gold_indices'], + ): subset = probabilities[start:end] idx_max = subset.index(max(subset)) confidence = torch.tensor(subset).max() / torch.tensor(subset).sum() @@ -686,14 +735,16 @@ def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): if idx_max == gold_idx: self.bucket_correct[ - bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + bucket_idx + ] += 1 # pyright: ignore [reportGeneralTypeIssues] self.bucket_totals[ bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] class InContextLearningLMExpectedCalibrationError( - InContextLearningExpectedCalibrationError): + InContextLearningExpectedCalibrationError, +): r"""Computes Expected Calibration Error (ECE) for In-context learning (ICL) language modeling (LM) tasks. (cite: https://arxiv.org/pdf/1706.04599.pdf). @@ -710,13 +761,16 @@ def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): outputs = torch.softmax(outputs, dim=2) for batch_idx, cont_idx in enumerate(batch['continuation_indices']): - cont_tok_logits = outputs[batch_idx].index_select(dim=0, - index=cont_idx - - 1) + cont_tok_logits = outputs[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ) cont_tok_pred = cont_tok_logits.argmax(dim=-1) confidence = cont_tok_logits.max(dim=-1).values.min() - cont_tok_targ = labels[batch_idx].index_select(dim=0, - index=cont_idx - 1) + cont_tok_targ = labels[batch_idx].index_select( + dim=0, + index=cont_idx - 1, + ) assert confidence >= 0.0 and confidence <= 1.0 bucket_idx = int(confidence * self.n_buckets) if bucket_idx == self.n_buckets: @@ -724,7 +778,8 @@ def update(self, batch: dict, outputs: torch.Tensor, labels: torch.Tensor): if (cont_tok_pred == cont_tok_targ).all(): self.bucket_correct[ - bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] + bucket_idx + ] += 1 # pyright: ignore [reportGeneralTypeIssues] self.bucket_totals[ bucket_idx] += 1 # pyright: ignore [reportGeneralTypeIssues] diff --git a/llmfoundry/interfaces/callback_with_config.py b/llmfoundry/interfaces/callback_with_config.py index e30d1793c4..3579174ae7 100644 --- a/llmfoundry/interfaces/callback_with_config.py +++ b/llmfoundry/interfaces/callback_with_config.py @@ -15,7 +15,11 @@ class CallbackWithConfig(Callback, abc.ABC): its other kwargs. """ - def __init__(self, config: dict[str, Any], *args: Any, - **kwargs: Any) -> None: + def __init__( + self, + config: dict[str, Any], + *args: Any, + **kwargs: Any, + ) -> None: del config, args, kwargs pass diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index 24593144aa..e618d03dc8 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -10,88 +10,109 @@ _norm_description = ( 'The norms registry is used to register classes that implement normalization layers.' ) -norms = create_registry('llmfoundry', - 'norms', - generic_type=Type[torch.nn.Module], - entry_points=True, - description=_norm_description) +norms = create_registry( + 'llmfoundry', + 'norms', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_norm_description, +) _fc_description = ( 'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).' + 'These classes should take in_features and out_features in as args, at a minimum.' ) -fcs = create_registry('llmfoundry', - 'fcs', - generic_type=Type[torch.nn.Module], - entry_points=True, - description=_fc_description) +fcs = create_registry( + 'llmfoundry', + 'fcs', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_fc_description, +) _ffns_description = ( 'The ffns registry is used to register functions that build ffn layers.' + - 'See ffn.py for examples.') -ffns = create_registry('llmfoundry', - 'ffns', - generic_type=Callable, - entry_points=True, - description=_ffns_description) + 'See ffn.py for examples.' +) +ffns = create_registry( + 'llmfoundry', + 'ffns', + generic_type=Callable, + entry_points=True, + description=_ffns_description, +) _ffns_with_norm_description = ( 'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.' - + 'See ffn.py for examples.') -ffns_with_norm = create_registry('llmfoundry', - 'ffns_with_norm', - generic_type=Callable, - entry_points=True, - description=_ffns_with_norm_description) + + 'See ffn.py for examples.' +) +ffns_with_norm = create_registry( + 'llmfoundry', + 'ffns_with_norm', + generic_type=Callable, + entry_points=True, + description=_ffns_with_norm_description, +) _ffns_with_megablocks_description = ( 'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.' - + 'See ffn.py for examples.') + + 'See ffn.py for examples.' +) ffns_with_megablocks = create_registry( 'llmfoundry', 'ffns_with_megablocks', generic_type=Callable, entry_points=True, - description=_ffns_with_megablocks_description) + description=_ffns_with_megablocks_description, +) _attention_classes_description = ( 'The attention_classes registry is used to register classes that implement attention layers. See ' - + 'attention.py for expected constructor signature.') -attention_classes = create_registry('llmfoundry', - 'attention_classes', - generic_type=Type[torch.nn.Module], - entry_points=True, - description=_attention_classes_description) + + 'attention.py for expected constructor signature.' +) +attention_classes = create_registry( + 'llmfoundry', + 'attention_classes', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_attention_classes_description, +) _attention_implementations_description = ( 'The attention_implementations registry is used to register functions that implement the attention operation.' - + 'See attention.py for expected function signature.') + + 'See attention.py for expected function signature.' +) attention_implementations = create_registry( 'llmfoundry', 'attention_implementations', generic_type=Callable, entry_points=True, - description=_attention_implementations_description) + description=_attention_implementations_description, +) _param_init_fns_description = ( 'The param_init_fns registry is used to register functions that initialize parameters.' + 'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.' ) -param_init_fns = create_registry('llmfoundry', - 'param_init_fns', - generic_type=Callable[..., None], - entry_points=True, - description=_param_init_fns_description) +param_init_fns = create_registry( + 'llmfoundry', + 'param_init_fns', + generic_type=Callable[..., None], + entry_points=True, + description=_param_init_fns_description, +) _module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules. These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents. They should take in the module, init_div_is_residual, and div_is_residual arguments.""" -module_init_fns = create_registry('llmfoundry', - 'module_init_fns', - generic_type=Callable[..., bool], - entry_points=True, - description=_module_init_fns_description) +module_init_fns = create_registry( + 'llmfoundry', + 'module_init_fns', + generic_type=Callable[..., bool], + entry_points=True, + description=_module_init_fns_description, +) __all__ = [ 'norms', diff --git a/llmfoundry/loggers/__init__.py b/llmfoundry/loggers/__init__.py index 4a8f75b35a..cd3f3fdc62 100644 --- a/llmfoundry/loggers/__init__.py +++ b/llmfoundry/loggers/__init__.py @@ -1,14 +1,20 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from composer.loggers import (InMemoryLogger, MLFlowLogger, TensorboardLogger, - WandBLogger) +from composer.loggers import ( + InMemoryLogger, + MLFlowLogger, + TensorboardLogger, + WandBLogger, +) from llmfoundry.registry import loggers loggers.register('wandb', func=WandBLogger) loggers.register('tensorboard', func=TensorboardLogger) loggers.register('inmemory', func=InMemoryLogger) -loggers.register('in_memory_logger', - func=InMemoryLogger) # for backwards compatibility +loggers.register( + 'in_memory_logger', + func=InMemoryLogger, +) # for backwards compatibility loggers.register('mlflow', func=MLFlowLogger) diff --git a/llmfoundry/metrics/__init__.py b/llmfoundry/metrics/__init__.py index 8ca2db5bd2..18067c3283 100644 --- a/llmfoundry/metrics/__init__.py +++ b/llmfoundry/metrics/__init__.py @@ -1,27 +1,38 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from composer.metrics import (LanguageCrossEntropy, LanguagePerplexity, - MaskedAccuracy) +from composer.metrics import ( + LanguageCrossEntropy, + LanguagePerplexity, + MaskedAccuracy, +) from llmfoundry.eval.metrics import ( InContextLearningCodeEvalAccuracy, - InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, + InContextLearningGenerationExactMatchAccuracy, + InContextLearningLMAccuracy, InContextLearningLMExpectedCalibrationError, InContextLearningMCExpectedCalibrationError, - InContextLearningMultipleChoiceAccuracy) + InContextLearningMultipleChoiceAccuracy, +) from llmfoundry.metrics.token_acc import TokenAccuracy from llmfoundry.registry import metrics metrics.register('token_accuracy', func=TokenAccuracy) metrics.register('lm_accuracy', func=InContextLearningLMAccuracy) -metrics.register('lm_expected_calibration_error', - func=InContextLearningLMExpectedCalibrationError) -metrics.register('mc_expected_calibration_error', - func=InContextLearningMCExpectedCalibrationError) +metrics.register( + 'lm_expected_calibration_error', + func=InContextLearningLMExpectedCalibrationError, +) +metrics.register( + 'mc_expected_calibration_error', + func=InContextLearningMCExpectedCalibrationError, +) metrics.register('mc_accuracy', func=InContextLearningMultipleChoiceAccuracy) -metrics.register('qa_accuracy', - func=InContextLearningGenerationExactMatchAccuracy) +metrics.register( + 'qa_accuracy', + func=InContextLearningGenerationExactMatchAccuracy, +) metrics.register('code_eval_accuracy', func=InContextLearningCodeEvalAccuracy) metrics.register('language_cross_entropy', func=LanguageCrossEntropy) metrics.register('language_perplexity', func=LanguagePerplexity) diff --git a/llmfoundry/metrics/token_acc.py b/llmfoundry/metrics/token_acc.py index 1cdcffe1db..6843220530 100644 --- a/llmfoundry/metrics/token_acc.py +++ b/llmfoundry/metrics/token_acc.py @@ -25,17 +25,23 @@ class TokenAccuracy(Metric): # Ensures torchmetrics calls update only once full_state_update = False - def __init__(self, - ignore_index: int = -100, - dist_sync_on_step: bool = False): + def __init__( + self, + ignore_index: int = -100, + dist_sync_on_step: bool = False, + ): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_index = ignore_index - self.add_state('correct_tokens', - default=torch.tensor(0), - dist_reduce_fx='sum') - self.add_state('total_tokens', - default=torch.tensor(0), - dist_reduce_fx='sum') + self.add_state( + 'correct_tokens', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) + self.add_state( + 'total_tokens', + default=torch.tensor(0), + dist_reduce_fx='sum', + ) def update(self, preds: torch.Tensor, target: torch.Tensor): """Updates the internal state with results from a new batch. diff --git a/llmfoundry/models/__init__.py b/llmfoundry/models/__init__.py index ea144225c0..827fe2ce56 100644 --- a/llmfoundry/models/__init__.py +++ b/llmfoundry/models/__init__.py @@ -2,12 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.hf import ComposerHFCausalLM, ComposerHFT5 -from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper, - FMAPIChatAPIEvalWrapper, - OpenAICausalLMEvalWrapper, - OpenAIChatAPIEvalWrapper) -from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, - MPTForCausalLM, MPTModel, MPTPreTrainedModel) +from llmfoundry.models.inference_api_wrapper import ( + FMAPICasualLMEvalWrapper, + FMAPIChatAPIEvalWrapper, + OpenAICausalLMEvalWrapper, + OpenAIChatAPIEvalWrapper, +) +from llmfoundry.models.mpt import ( + ComposerMPTCausalLM, + MPTConfig, + MPTForCausalLM, + MPTModel, + MPTPreTrainedModel, +) from llmfoundry.registry import models models.register('mpt_causal_lm', func=ComposerMPTCausalLM) diff --git a/llmfoundry/models/hf/__init__.py b/llmfoundry/models/hf/__init__.py index 2ed7b2d6e1..b34281cd81 100644 --- a/llmfoundry/models/hf/__init__.py +++ b/llmfoundry/models/hf/__init__.py @@ -2,9 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM -from llmfoundry.models.hf.hf_fsdp import (prepare_hf_causal_lm_model_for_fsdp, - prepare_hf_enc_dec_model_for_fsdp, - prepare_hf_model_for_fsdp) +from llmfoundry.models.hf.hf_fsdp import ( + prepare_hf_causal_lm_model_for_fsdp, + prepare_hf_enc_dec_model_for_fsdp, + prepare_hf_model_for_fsdp, +) from llmfoundry.models.hf.hf_t5 import ComposerHFT5 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index d225bfbd5f..80b62936ed 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -12,11 +12,18 @@ from composer.utils import dist from omegaconf import DictConfig from torchmetrics import Metric -from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig, - PreTrainedModel, PreTrainedTokenizerBase) - -from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS, - DEFAULT_CAUSAL_LM_TRAIN_METRICS) +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) + +from llmfoundry.metrics import ( + DEFAULT_CAUSAL_LM_EVAL_METRICS, + DEFAULT_CAUSAL_LM_TRAIN_METRICS, +) from llmfoundry.models.hf.hf_fsdp import hf_get_init_device from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers.attention import is_flash_v2_installed @@ -60,20 +67,26 @@ class ComposerHFCausalLM(HuggingFaceModelWithFSDP): tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ - def __init__(self, om_model_config: DictConfig, - tokenizer: PreTrainedTokenizerBase): + def __init__( + self, + om_model_config: DictConfig, + tokenizer: PreTrainedTokenizerBase, + ): model = ComposerHFCausalLM.build_inner_model(om_model_config) train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics( - om_model_config) + om_model_config, + ) - peft_config_dict = pop_config(om_model_config, - 'peft_config', - must_exist=False, - convert=True) + peft_config_dict = pop_config( + om_model_config, + 'peft_config', + must_exist=False, + convert=True, + ) if peft_config_dict is not None and not peft_installed: raise ValueError( - 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.' + 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.', ) peft_config = None @@ -95,7 +108,8 @@ def __init__(self, om_model_config: DictConfig, @staticmethod def build_metrics( - om_model_config: DictConfig) -> Tuple[List[Metric], List[Metric]]: + om_model_config: DictConfig, + ) -> Tuple[List[Metric], List[Metric]]: """Builds the training and evaluation metrics for the model. Args: @@ -105,12 +119,16 @@ def build_metrics( use_train_metrics = om_model_config.get('use_train_metrics', True) train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + om_model_config.get( - 'additional_train_metrics', []) + 'additional_train_metrics', + [], + ) train_metrics = [ build_metric(metric, {}) for metric in train_metric_names ] if use_train_metrics else [] eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + om_model_config.get( - 'additional_eval_metrics', []) + 'additional_eval_metrics', + [], + ) eval_metrics = [ build_metric(metric, {}) for metric in eval_metric_names ] @@ -119,8 +137,8 @@ def build_metrics( @staticmethod def build_inner_model( - om_model_config: DictConfig, - prepare_for_fsdp: bool = False + om_model_config: DictConfig, + prepare_for_fsdp: bool = False, ) -> Union[PreTrainedModel, 'PeftModel']: """Builds the inner model for the ComposerHFCausalLM. @@ -130,22 +148,27 @@ def build_inner_model( """ pretrained_model_name_or_path = om_model_config.pretrained_model_name_or_path pretrained_lora_id_or_path = om_model_config.get( - 'pretrained_lora_id_or_path', None) + 'pretrained_lora_id_or_path', + None, + ) if not om_model_config.get( - 'trust_remote_code', True + 'trust_remote_code', + True, ) and pretrained_model_name_or_path.startswith('mosaicml/mpt'): raise ValueError( 'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, ' + - 'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.' + 'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.', ) # Set up Hugging Face args trust_remote_code = om_model_config.get('trust_remote_code', True) use_auth_token = om_model_config.get('use_auth_token', False) - use_flash_attention_2 = om_model_config.get('use_flash_attention_2', - False) + use_flash_attention_2 = om_model_config.get( + 'use_flash_attention_2', + False, + ) load_in_8bit = om_model_config.get('load_in_8bit', False) # Set up config args for the model construction and base classes @@ -157,7 +180,8 @@ def build_inner_model( if use_flash_attention_2 and not is_flash_v2_installed(): raise ValueError( 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. ' - + 'Please `pip install llm-foundry[gpu]`.') + + 'Please `pip install llm-foundry[gpu]`.', + ) # Construct the Hugging Face config to use config = AutoConfig.from_pretrained( @@ -174,21 +198,23 @@ def build_inner_model( # the model and then casting it back to fp32, we are monkeypatching their check. # https://github.com/huggingface/transformers/issues/28052 def _autoset_attn_implementation_monkeypatch( - cls, # type: ignore - config, # type: ignore - *args, # type: ignore - **kwargs): # type: ignore + cls, # type: ignore + config, # type: ignore + *args, # type: ignore + **kwargs, # type: ignore + ): # type: ignore config._attn_implementation = requested_attention_implementation return config PreTrainedModel._autoset_attn_implementation = classmethod( - _autoset_attn_implementation_monkeypatch) + _autoset_attn_implementation_monkeypatch, + ) # set config overrides for k, v in om_model_config.get('config_overrides', {}).items(): if not hasattr(config, k): raise ValueError( - f'config does not have attribute "{k}" to override ({k}: {v}).' + f'config does not have attribute "{k}" to override ({k}: {v}).', ) attr = getattr(config, k) @@ -199,7 +225,8 @@ def _autoset_attn_implementation_monkeypatch( raise ValueError( f'Config dict override got unknown keys. ' + f'Extra keys: {extra_keys}. ' + - f'Expected (a subset of) keys: {list(attr.keys())}.') + f'Expected (a subset of) keys: {list(attr.keys())}.', + ) getattr(config, k).update(v) # necessary case to allow for rope_scaling to be overriden in llama config elif attr is None and isinstance(v, Mapping): @@ -208,13 +235,13 @@ def _autoset_attn_implementation_monkeypatch( elif isinstance(attr, PretrainedConfig): if not isinstance(v, Mapping): raise ValueError( - f'Expected a dictionary for config override {k}, but got {v}.' + f'Expected a dictionary for config override {k}, but got {v}.', ) for _k, _v in v.items(): if not hasattr(attr, _k): raise ValueError( - f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).' + f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', ) setattr(attr, _k, _v) else: @@ -229,8 +256,8 @@ def _autoset_attn_implementation_monkeypatch( # transformers modules cache. On particular systems, this operation seems to cause contention between # the different processes. To avoid this contention, we first create the model (on meta device) on local rank # zero. This will set up the transformers model cache and avoid the future contention. - if dist.get_local_rank() == 0 and os.path.isdir( - pretrained_model_name_or_path): + if dist.get_local_rank( + ) == 0 and os.path.isdir(pretrained_model_name_or_path): with init_empty_weights(include_buffers=False): with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) @@ -261,7 +288,7 @@ def _autoset_attn_implementation_monkeypatch( elif resolved_init_device == 'meta': if om_model_config.pretrained: raise ValueError( - 'Setting cfg.pretrained=True is not supported when init_device="meta".' + 'Setting cfg.pretrained=True is not supported when init_device="meta".', ) with init_empty_weights(include_buffers=False): model = AutoModelForCausalLM.from_config( @@ -270,7 +297,8 @@ def _autoset_attn_implementation_monkeypatch( ) else: raise ValueError( - f'init_device="{init_device}" must be either "cpu" or "meta".') + f'init_device="{init_device}" must be either "cpu" or "meta".', + ) signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed' if dist.get_local_rank() == 0: @@ -294,11 +322,13 @@ def _autoset_attn_implementation_monkeypatch( if pretrained_lora_id_or_path is not None: if not peft_installed: raise ValueError( - 'PEFT is not installed, but lora_id_or_path was passed. Please install LLM Foundry with the peft extra to use lora_id_or_path.' + 'PEFT is not installed, but lora_id_or_path was passed. Please install LLM Foundry with the peft extra to use lora_id_or_path.', ) from peft import PeftModelForCausalLM model = PeftModelForCausalLM.from_pretrained( - model, pretrained_lora_id_or_path) + model, + pretrained_lora_id_or_path, + ) if prepare_for_fsdp: ComposerHFCausalLM.prepare_inner_model(model, init_device) @@ -311,15 +341,15 @@ def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig': peft_type = peft_config_dict.get('peft_type', '') if peft_type.upper() != 'LORA': raise ValueError( - f'Only LORA is supported for peft_type, but got {peft_type}.' + f'Only LORA is supported for peft_type, but got {peft_type}.', ) task_type = peft_config_dict.get('task_type', '') if task_type.upper() != 'CAUSAL_LM': raise ValueError( - f'Only CAUSAL_LM is supported for task_type, but got {task_type}.' + f'Only CAUSAL_LM is supported for task_type, but got {task_type}.', ) return LoraConfig(**peft_config_dict) else: raise ValueError( - 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.' + 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.', ) diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py index 87bffc3af8..00dada5532 100644 --- a/llmfoundry/models/hf/hf_fsdp.py +++ b/llmfoundry/models/hf/hf_fsdp.py @@ -75,12 +75,16 @@ def hf_get_causal_base_model(model: PreTrainedModel) -> Any: if hasattr(model, 'get_decoder'): return model.get_decoder() - decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox', - 'model.transformer') + decoder_attrs = ( + 'transformer', + 'model.decoder', + 'gpt_neox', + 'model.transformer', + ) causal_base_model = findattr(model, decoder_attrs) if causal_base_model is None: raise ValueError( - f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.' + f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.', ) return causal_base_model @@ -106,7 +110,7 @@ def hf_get_hidden_layers(model: PreTrainedModel) -> Any: layers = findattr(model, hidden_layers_attrs) if layers is None: raise ValueError( - f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}' + f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}', ) return layers @@ -124,8 +128,10 @@ def hf_get_init_device(init_device: Optional[str]) -> Optional[str]: # /end helper functions -def prepare_hf_model_for_fsdp(model: PreTrainedModel, - init_device: Optional[str]) -> None: +def prepare_hf_model_for_fsdp( + model: PreTrainedModel, + init_device: Optional[str], +) -> None: """FSDP wrap a HuggingFace model. Call specific functions @@ -138,9 +144,10 @@ def prepare_hf_model_for_fsdp(model: PreTrainedModel, prepare_hf_causal_lm_model_for_fsdp(model, init_device) -def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, - 'PeftModel'], - init_device: Optional[str]) -> None: +def prepare_hf_causal_lm_model_for_fsdp( + model: Union[PreTrainedModel, 'PeftModel'], + init_device: Optional[str], +) -> None: """FSDP wrap a HuggingFace decoder. Wrap any model for FSDP which follows one of the 3 existing conventions from @@ -165,14 +172,15 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, 'base_model': causal_base_model, 'model_block': model_block, 'lm_head': lm_head, - 'tied_embeddings': tied_embeddings + 'tied_embeddings': tied_embeddings, } for mod_name, module in modules.items(): if module is None: raise ValueError( f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + - 'follow common layer/weight naming conventions.') + 'follow common layer/weight naming conventions.', + ) block_type = type(model_block[0]) # When using the HF LM models, @@ -194,7 +202,8 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, active_adapters = [adapter.lower() for adapter in model.active_adapters] for name, module in model.named_modules(): if peft_type in name.lower() and any( - adapter in name.lower() for adapter in active_adapters): + adapter in name.lower() for adapter in active_adapters + ): has_parameters = next(module.parameters(), None) is not None has_buffers = next(module.buffers(), None) is not None if has_parameters or has_buffers: @@ -203,11 +212,15 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, # FSDP Wrap and Activation Checkpoint every model block model.fsdp_wrap_fn = lambda module: isinstance(module, block_type) model.activation_checkpointing_fn = lambda module: isinstance( - module, block_type) + module, + block_type, + ) -def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, - init_device: Optional[str]) -> None: +def prepare_hf_enc_dec_model_for_fsdp( + model: PreTrainedModel, + init_device: Optional[str], +) -> None: """Wrap an encoder/decoder HF model. This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet) @@ -230,14 +243,15 @@ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, 'encoder_block': encoder_block, 'decoder_block': decoder_block, 'lm_head': lm_head, - 'tied_embeddings': tied_embeddings + 'tied_embeddings': tied_embeddings, } for mod_name, module in modules.items(): if module is None: raise ValueError( f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + - 'follow common layer/weight naming conventions.') + 'follow common layer/weight naming conventions.', + ) decoder_block_type = type(decoder_block[0]) encoder_block_type = type(encoder_block[0]) @@ -251,7 +265,9 @@ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, # FSDP Wrap and Activation Checkpoint every decoder block model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type) model.activation_checkpointing_fn = lambda module: isinstance( - module, decoder_block_type) + module, + decoder_block_type, + ) if encoder_block_type == decoder_block_type: return @@ -259,4 +275,6 @@ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, # need to wrap encoder blocks separately for ProphetNet and Marian model.fsdp_wrap_fn = lambda module: isinstance(module, encoder_block_type) model.activation_checkpointing_fn = lambda module: isinstance( - module, encoder_block_type) + module, + encoder_block_type, + ) diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py index b9c1df64cf..6520fe7426 100644 --- a/llmfoundry/models/hf/hf_t5.py +++ b/llmfoundry/models/hf/hf_t5.py @@ -9,8 +9,11 @@ from composer.utils import dist from omegaconf import DictConfig -from transformers import (AutoConfig, PreTrainedTokenizerBase, - T5ForConditionalGeneration) +from transformers import ( + AutoConfig, + PreTrainedTokenizerBase, + T5ForConditionalGeneration, +) from llmfoundry.metrics import DEFAULT_ENC_DEC_METRICS from llmfoundry.models.hf.hf_fsdp import hf_get_init_device @@ -44,8 +47,11 @@ class ComposerHFT5(HuggingFaceModelWithFSDP): tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ - def __init__(self, om_model_config: DictConfig, - tokenizer: PreTrainedTokenizerBase): + def __init__( + self, + om_model_config: DictConfig, + tokenizer: PreTrainedTokenizerBase, + ): from llmfoundry.utils.builders import build_metric config = AutoConfig.from_pretrained( @@ -58,7 +64,7 @@ def __init__(self, om_model_config: DictConfig, for k, v in om_model_config.get('config_overrides', {}).items(): if not hasattr(config, k): raise ValueError( - f'config does not have attribute "{k}" to override ({k}: {v}).' + f'config does not have attribute "{k}" to override ({k}: {v}).', ) attr = getattr(config, k) @@ -68,7 +74,8 @@ def __init__(self, om_model_config: DictConfig, raise ValueError( f'Config dict override got unknown keys. ' + f'Extra keys: {extra_keys}. ' + - f'Expected (a subset of) keys: {list(attr.keys())}.') + f'Expected (a subset of) keys: {list(attr.keys())}.', + ) getattr(config, k).update(v) else: setattr(config, k, v) @@ -92,28 +99,30 @@ def __init__(self, om_model_config: DictConfig, if om_model_config.pretrained: model = T5ForConditionalGeneration.from_pretrained( om_model_config.pretrained_model_name_or_path, - config=config) + config=config, + ) else: model = T5ForConditionalGeneration(config) elif resolved_init_device == 'meta': if om_model_config.pretrained: raise ValueError( - 'Setting cfg.pretrained=True is not supported when init_device="meta".' + 'Setting cfg.pretrained=True is not supported when init_device="meta".', ) with init_empty_weights(include_buffers=False): model = T5ForConditionalGeneration(config) else: raise ValueError( - f'init_device="{init_device}" must be either "cpu" or "meta".') + f'init_device="{init_device}" must be either "cpu" or "meta".', + ) metrics = [ build_metric(metric, {}) for metric in DEFAULT_ENC_DEC_METRICS + om_model_config.get('additional_train_metrics', []) ] - composer_model = super().__init__(model=model, - tokenizer=tokenizer, - metrics=metrics, - init_device=init_device) - - return composer_model + super().__init__( + model=model, + tokenizer=tokenizer, + metrics=metrics, + init_device=init_device, + ) diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py index 3bfa5d2ad3..c667c6026a 100644 --- a/llmfoundry/models/hf/model_wrapper.py +++ b/llmfoundry/models/hf/model_wrapper.py @@ -31,14 +31,16 @@ class HuggingFaceModelWithFSDP(HuggingFaceModel): Handles preparation for FSDP wrapping. """ - def __init__(self, - model: Union[transformers.PreTrainedModel, 'PeftModel'], - tokenizer: Optional[PreTrainedTokenizerBase] = None, - metrics: Optional[List[Metric]] = None, - eval_metrics: Optional[List[Metric]] = None, - shift_labels: bool = False, - init_device: Optional[str] = None, - peft_config: Optional['PeftConfig'] = None): + def __init__( + self, + model: Union[transformers.PreTrainedModel, 'PeftModel'], + tokenizer: Optional[PreTrainedTokenizerBase] = None, + metrics: Optional[List[Metric]] = None, + eval_metrics: Optional[List[Metric]] = None, + shift_labels: bool = False, + init_device: Optional[str] = None, + peft_config: Optional['PeftConfig'] = None, + ): super().__init__( model, tokenizer, @@ -61,7 +63,7 @@ def forward(self, batch: Mapping): output = self.model(**batch) # type: ignore (thirdparty) else: raise ValueError( - 'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model' + 'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model', ) return output @@ -72,9 +74,10 @@ def loss(self, outputs: ModelOutput, batch: Mapping): return outputs[:2] @staticmethod - def prepare_inner_model(model: Union[transformers.PreTrainedModel, - 'PeftModel'], - init_device: Optional[str] = None): + def prepare_inner_model( + model: Union[transformers.PreTrainedModel, 'PeftModel'], + init_device: Optional[str] = None, + ): """Prepare the inner model for FSDP wrapping. Args: diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py index 905abf2fa1..936c711ad6 100644 --- a/llmfoundry/models/inference_api_wrapper/__init__.py +++ b/llmfoundry/models/inference_api_wrapper/__init__.py @@ -2,11 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.inference_api_wrapper.fmapi import ( - FMAPICasualLMEvalWrapper, FMAPIChatAPIEvalWrapper, FMAPIEvalInterface) + FMAPICasualLMEvalWrapper, + FMAPIChatAPIEvalWrapper, + FMAPIEvalInterface, +) from llmfoundry.models.inference_api_wrapper.interface import \ InferenceAPIEvalWrapper from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( - OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface) + OpenAICausalLMEvalWrapper, + OpenAIChatAPIEvalWrapper, + OpenAIEvalInterface, +) __all__ = [ 'OpenAICausalLMEvalWrapper', diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py index d0c987304a..a2c78800a9 100644 --- a/llmfoundry/models/inference_api_wrapper/fmapi.py +++ b/llmfoundry/models/inference_api_wrapper/fmapi.py @@ -10,7 +10,10 @@ from transformers import AutoTokenizer from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( - OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAIEvalInterface) + OpenAICausalLMEvalWrapper, + OpenAIChatAPIEvalWrapper, + OpenAIEvalInterface, +) __all__ = [ 'FMAPICasualLMEvalWrapper', @@ -38,27 +41,29 @@ def block_until_ready(self, base_url: str): break except requests.exceptions.ConnectionError: log.debug( - f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds' + f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds', ) time.sleep(sleep_s) waited_s += sleep_s if waited_s >= timeout_s: raise TimeoutError( - f'Endpoint {ping_url} did not become read after {waited_s:,} seconds, exiting' + f'Endpoint {ping_url} did not become read after {waited_s:,} seconds, exiting', ) def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer): is_local = om_model_config.pop('local', False) if is_local: - base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT', - 'http://0.0.0.0:8080/v2') + base_url = os.environ.get( + 'MOSAICML_MODEL_ENDPOINT', + 'http://0.0.0.0:8080/v2', + ) om_model_config['base_url'] = base_url self.block_until_ready(base_url) if 'base_url' not in om_model_config: raise ValueError( - 'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper' + 'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper', ) super().__init__(om_model_config, tokenizer) diff --git a/llmfoundry/models/inference_api_wrapper/interface.py b/llmfoundry/models/inference_api_wrapper/interface.py index a939d03d68..6d231441ae 100644 --- a/llmfoundry/models/inference_api_wrapper/interface.py +++ b/llmfoundry/models/inference_api_wrapper/interface.py @@ -54,8 +54,10 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None): # model's generate function. Extra generation kwargs can be passed in via the batch. Strings will # be returned from eval_forward output_logits_batch = [] - for tokens, cont_idxs in zip(batch['input_ids'], - batch['continuation_indices']): + for tokens, cont_idxs in zip( + batch['input_ids'], + batch['continuation_indices'], + ): seqlen = tokens.shape[0] tokens = tokens.tolist() @@ -63,20 +65,24 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None): expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1] output_logits = torch.nn.functional.one_hot( torch.tensor(tokens[1:cont_idxs[0]]), - num_classes=len(self.tokenizer)) + num_classes=len(self.tokenizer), + ) for i in range(len(expected_cont_tokens)): # decode one token at a time - prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] + - expected_cont_tokens[0:i]) + prompt = self.tokenizer.decode( + tokens[:cont_idxs[0]] + expected_cont_tokens[0:i], + ) next_logit_tensor = self.get_next_token_logit_tensor(prompt) if next_logit_tensor is None: continue - output_logits = torch.cat( - [output_logits, - next_logit_tensor.reshape(1, -1)]) + output_logits = torch.cat([ + output_logits, + next_logit_tensor.reshape(1, -1), + ]) padding = torch.nn.functional.one_hot( torch.full((seqlen - output_logits.shape[0],), padding_tok), - num_classes=len(self.tokenizer)) + num_classes=len(self.tokenizer), + ) output_logits = torch.cat([output_logits, padding]) output_logits_batch.append(output_logits) @@ -87,18 +93,21 @@ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None: self.labels = batch.pop('labels') self.labels[:, :-1] = self.labels[:, 1:].clone() self.labels[:, -1] = -100 - if isinstance(metric, InContextLearningMetric) and batch.get( - 'mode', None) == 'icl_task': + if isinstance( + metric, + InContextLearningMetric, + ) and batch.get('mode', None) == 'icl_task': assert self.labels is not None metric.update(batch, outputs, self.labels) else: raise NotImplementedError( - 'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task' + 'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task', ) def forward(self): raise NotImplementedError( - "Inference API wrapper doesn't support forward") + "Inference API wrapper doesn't support forward", + ) def loss(self): raise NotImplementedError("Inference API wrapper doesn't support loss") diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py index 9f2cf3315c..fb26c7990c 100644 --- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -36,8 +36,11 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper): - def __init__(self, om_model_config: DictConfig, - tokenizer: AutoTokenizer) -> None: + def __init__( + self, + om_model_config: DictConfig, + tokenizer: AutoTokenizer, + ) -> None: super().__init__(om_model_config, tokenizer) try: import openai @@ -45,7 +48,8 @@ def __init__(self, om_model_config: DictConfig, raise MissingConditionalImportError( extra_deps_group='openai', conda_package='openai', - conda_channel='conda-forge') from e + conda_channel='conda-forge', + ) from e api_key = os.environ.get('OPENAI_API_KEY') base_url = om_model_config.get('base_url') @@ -53,13 +57,13 @@ def __init__(self, om_model_config: DictConfig, # Using OpenAI default, where the API key is required if api_key is None: raise ValueError( - 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.' + 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.', ) else: # Using a custom base URL, where the API key may not be required log.info( - f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}' + f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}', ) api_key = 'placeholder' # This cannot be None @@ -86,7 +90,8 @@ def try_generate_completion(self, prompt: str, num_tokens: int): raise MissingConditionalImportError( extra_deps_group='openai', conda_package='openai', - conda_channel='conda-forge') from e + conda_channel='conda-forge', + ) from e tries = 0 completion = None delay = 1 @@ -97,7 +102,8 @@ def try_generate_completion(self, prompt: str, num_tokens: int): break except RateLimitError as e: if 'You exceeded your current quota' in str( - e._message): # pyright: ignore + e._message, + ): # pyright: ignore raise e delay *= 2 * (1 + random.random()) sleep(delay) @@ -112,8 +118,11 @@ def try_generate_completion(self, prompt: str, num_tokens: int): class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface): - def __init__(self, om_model_config: DictConfig, - tokenizer: AutoTokenizer) -> None: + def __init__( + self, + om_model_config: DictConfig, + tokenizer: AutoTokenizer, + ) -> None: super().__init__(om_model_config, tokenizer) self.generate_completion = lambda prompt, num_tokens: self.client.chat.completions.create( @@ -122,14 +131,17 @@ def __init__(self, om_model_config: DictConfig, 'role': 'system', 'content': - om_model_config.get('system_role_prompt', - 'Please complete the following text: ') + om_model_config.get( + 'system_role_prompt', + 'Please complete the following text: ', + ), }, { 'role': 'user', - 'content': prompt + 'content': prompt, }], max_tokens=num_tokens, - temperature=0.0) + temperature=0.0, + ) def retokenize(self, tokens: List[int], cont_idxs: List[int]): """Chat API will never respond with a word-initial space. @@ -139,24 +151,33 @@ def retokenize(self, tokens: List[int], cont_idxs: List[int]): """ original_len = len(tokens) retokenized_continuation = self.tokenizer( - self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] + - 1]).strip())['input_ids'] + self.tokenizer.decode( + tokens[cont_idxs[0]:cont_idxs[-1] + 1], + ).strip(), + )['input_ids'] # replace the original continuation with the retokenized continuation + padding padding = [tokens[-1]] * ( - len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation)) + len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation) + ) tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding if len(tokens) > original_len: # this only happens if we were already at max seq len and the continuation got LARGER tokens = tokens[-original_len:] cont_idxs = list( - range(original_len - len(retokenized_continuation), - original_len)) + range( + original_len - len(retokenized_continuation), + original_len, + ), + ) else: cont_idxs = list( - range(cont_idxs[0], - cont_idxs[0] + len(retokenized_continuation))) + range( + cont_idxs[0], + cont_idxs[0] + len(retokenized_continuation), + ), + ) return torch.tensor(tokens), torch.tensor(cont_idxs) def rebatch(self, batch: Batch): @@ -168,12 +189,16 @@ def rebatch(self, batch: Batch): new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = { 'input_ids': [], 'continuation_indices': [], - 'labels': [] + 'labels': [], } - for tokens, cont_idxs in zip(batch['input_ids'], - batch['continuation_indices']): - tokens, cont_idxs = self.retokenize(tokens.tolist(), - cont_idxs.tolist()) + for tokens, cont_idxs in zip( + batch['input_ids'], + batch['continuation_indices'], + ): + tokens, cont_idxs = self.retokenize( + tokens.tolist(), + cont_idxs.tolist(), + ) assert isinstance(new_batch['input_ids'], list) new_batch['input_ids'].append(tokens) @@ -199,8 +224,10 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None): padding_tok = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id output_logits_batch = [] batch = self.rebatch(batch) - for tokens, cont_idxs in zip(batch['input_ids'], - batch['continuation_indices']): + for tokens, cont_idxs in zip( + batch['input_ids'], + batch['continuation_indices'], + ): seqlen = tokens.shape[0] tokens = tokens.tolist() @@ -208,17 +235,21 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None): expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1] output_logits = torch.nn.functional.one_hot( torch.tensor(tokens[1:cont_idxs[0]]), - num_classes=len(self.tokenizer)) + num_classes=len(self.tokenizer), + ) prompt = self.tokenizer.decode(tokens[:cont_idxs[0]]) next_logit_tensor = self.get_next_token_logit_tensor( - prompt, num_tokens=len(expected_cont_tokens)) + prompt, + num_tokens=len(expected_cont_tokens), + ) if next_logit_tensor is not None: output_logits = torch.cat([output_logits, next_logit_tensor]) padding = torch.nn.functional.one_hot( torch.full((seqlen - output_logits.shape[0],), padding_tok), - num_classes=len(self.tokenizer)) + num_classes=len(self.tokenizer), + ) output_logits = torch.cat([output_logits, padding]) output_logits_batch.append(output_logits) @@ -231,7 +262,8 @@ def process_result(self, completion: Optional['ChatCompletion']): if len(completion.choices) > 0: tensors = [] for t in self.tokenizer( - completion.choices[0].message.content)['input_ids']: + completion.choices[0].message.content, + )['input_ids']: # Not real logprobs tensor = torch.tensor([0] * (len(self.tokenizer))) tensor[t] = 1.0 @@ -248,15 +280,19 @@ def process_result(self, completion: Optional['ChatCompletion']): class OpenAICausalLMEvalWrapper(OpenAIEvalInterface): - def __init__(self, om_model_config: DictConfig, - tokenizer: AutoTokenizer) -> None: + def __init__( + self, + om_model_config: DictConfig, + tokenizer: AutoTokenizer, + ) -> None: super().__init__(om_model_config, tokenizer) self.generate_completion = lambda prompt, num_tokens: self.client.completions.create( model=self.model_name, prompt=prompt, max_tokens=num_tokens, logprobs=5, - temperature=0.0) + temperature=0.0, + ) def process_result(self, completion: Optional['Completion']): if completion is None: @@ -270,7 +306,8 @@ def process_result(self, completion: Optional['Completion']): if len(completion.choices[0].logprobs.top_logprobs[0]) > 0: # Construct tensor of shape (vocab_size,) with logprobs for each token tokenizer_logprobs = dict( - completion.choices[0].logprobs.top_logprobs[0]) + completion.choices[0].logprobs.top_logprobs[0], + ) tensor = torch.tensor([min(tokenizer_logprobs.values()) - 1] * (len(self.tokenizer))) for k in tokenizer_logprobs: diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index e31029024c..92f1ea7aec 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -2,19 +2,34 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.layers.attention import ( - GroupedQueryAttention, MultiheadAttention, MultiQueryAttention, - attn_bias_shape, build_alibi_bias, build_attn_bias, check_alibi_support, - flash_attn_fn, scaled_multihead_dot_product_attention) + GroupedQueryAttention, + MultiheadAttention, + MultiQueryAttention, + attn_bias_shape, + build_alibi_bias, + build_attn_bias, + check_alibi_support, + flash_attn_fn, + scaled_multihead_dot_product_attention, +) from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.dmoe import DroplessMLP, LearnedRouter, dMoE from llmfoundry.models.layers.fc import * from llmfoundry.models.layers.ffn import MPTGLU, MPTMLP -from llmfoundry.models.layers.layer_builders import (build_attention_layer, - build_fc, build_ffn, - build_norm) -from llmfoundry.models.layers.norm import (LPLayerNorm, LPRMSNorm, RMSNorm, - TritonRMSNorm, rms_norm) +from llmfoundry.models.layers.layer_builders import ( + build_attention_layer, + build_fc, + build_ffn, + build_norm, +) +from llmfoundry.models.layers.norm import ( + LPLayerNorm, + LPRMSNorm, + RMSNorm, + TritonRMSNorm, + rms_norm, +) __all__ = [ 'scaled_multihead_dot_product_attention', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index d4a34eecaa..82fee68af6 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -14,8 +14,10 @@ from packaging import version from torch import nn -from llmfoundry.layers_registry import (attention_classes, - attention_implementations) +from llmfoundry.layers_registry import ( + attention_classes, + attention_implementations, +) from llmfoundry.models.layers.layer_builders import build_fc, build_norm __all__ = [ @@ -54,20 +56,24 @@ def is_transformers_version_gte(hf_version: str) -> bool: def check_alibi_support(attention_impl: str) -> bool: return attention_impl != 'flash' or is_flash_v2_installed( - v2_version='v2.4.2') + v2_version='v2.4.2', + ) from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, - original_is_causal: bool) -> bool: +def _reset_is_causal( + num_query_tokens: int, + num_key_tokens: int, + original_is_causal: bool, +) -> bool: # disable causal when it is not needed # necessary for flash for generation with kv_cache if original_is_causal and num_query_tokens != num_key_tokens: if num_query_tokens != 1: raise NotImplementedError( - 'MPT does not support query and key with different number of tokens, unless number of query tokens is 1.' + 'MPT does not support query and key with different number of tokens, unless number of query tokens is 1.', ) else: return False @@ -147,11 +153,10 @@ def scaled_multihead_dot_product_attention( _s_k = max(0, attn_bias.size(3) - s_k) attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if (attn_bias.size(-1) != 1 and - attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and - attn_bias.size(-2) != s_q): + if (attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k + ) or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q): raise RuntimeError( - f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.' + f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.', ) attn_weight = attn_weight + attn_bias @@ -164,10 +169,12 @@ def scaled_multihead_dot_product_attention( 'and applying it within the attention module can cause ' +\ 'unnecessary computation/memory usage. Consider integrating ' +\ 'into attn_bias once and passing that to each attention ' +\ - 'module instead.' + 'module instead.', ) attn_weight = attn_weight.masked_fill( - ~key_padding_mask.view((b, 1, 1, s_k)), min_val) + ~key_padding_mask.view((b, 1, 1, s_k)), + min_val, + ) if is_causal and (not q.size(2) == 1): s = max(s_q, s_k) @@ -176,16 +183,20 @@ def scaled_multihead_dot_product_attention( causal_mask = causal_mask.to(torch.bool) causal_mask = ~causal_mask causal_mask = causal_mask[-s_q:, -s_k:] - attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), - min_val) + attn_weight = attn_weight.masked_fill( + causal_mask.view(1, 1, s_q, s_k), + min_val, + ) attn_weight = torch.softmax(attn_weight, dim=-1) if dropout_p: - attn_weight = torch.nn.functional.dropout(attn_weight, - p=dropout_p, - training=training, - inplace=True) + attn_weight = torch.nn.functional.dropout( + attn_weight, + p=dropout_p, + training=training, + inplace=True, + ) out = attn_weight.to(v.dtype).matmul(v) out = rearrange(out, 'b h s d -> b s (h d)') @@ -195,8 +206,10 @@ def scaled_multihead_dot_product_attention( return out, None, past_key_value -def check_valid_inputs(*tensors: torch.Tensor, - valid_dtypes: Optional[list[torch.dtype]] = None): +def check_valid_inputs( + *tensors: torch.Tensor, + valid_dtypes: Optional[list[torch.dtype]] = None, +): if valid_dtypes is None: valid_dtypes = [torch.float16, torch.bfloat16] for tensor in tensors: @@ -236,7 +249,8 @@ def flash_attn_fn( from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: raise RuntimeError( - 'Please install flash-attn==1.0.9 or flash-attn==2.3.6') + 'Please install flash-attn==1.0.9 or flash-attn==2.3.6', + ) check_valid_inputs(query, key, value) @@ -261,21 +275,27 @@ def flash_attn_fn( max_seqlen_k = flash_attn_padding_info['max_seqlen_k'] query_unpad = bert_padding.index_first_axis( - rearrange(query, 'b s ... -> (b s) ...'), indices_q) + rearrange(query, 'b s ... -> (b s) ...'), + indices_q, + ) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) key_unpad = bert_padding.index_first_axis( - rearrange(key, 'b s ... -> (b s) ...'), indices_k) + rearrange(key, 'b s ... -> (b s) ...'), + indices_k, + ) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) value_unpad = bert_padding.index_first_axis( - rearrange(value, 'b s ... -> (b s) ...'), indices_v) + rearrange(value, 'b s ... -> (b s) ...'), + indices_v, + ) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( - not should_repeat_kv_for_gqa): + if (kv_n_heads < n_heads) and (not is_flash_v2_installed() + ) and (not should_repeat_kv_for_gqa): raise ValueError( - 'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.' + 'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.', ) if should_repeat_kv_for_gqa: @@ -287,10 +307,16 @@ def flash_attn_fn( # - pytorch docs # # hopefully the kernels can utilize this and we're jot just wasting BW here - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, - key_unpad.size(-1)) - value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, - value_unpad.size(-1)) + key_unpad = key_unpad.expand( + key_unpad.size(0), + n_heads, + key_unpad.size(-1), + ) + value_unpad = value_unpad.expand( + value_unpad.size(0), + n_heads, + value_unpad.size(-1), + ) # grouped query case elif kv_n_heads < n_heads: # Each query belong to a group of kv heads of group size n_heads // kv_n_heads @@ -301,10 +327,12 @@ def flash_attn_fn( key_unpad = repeat_kv_for_gqa( key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), - n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1) + n_heads // kv_n_heads, + ).view(key_unpad.size(0), n_heads, -1) value_unpad = repeat_kv_for_gqa( value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), - n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1) + n_heads // kv_n_heads, + ).view(value_unpad.size(0), n_heads, -1) dropout_p = dropout_p if training else 0.0 @@ -322,14 +350,16 @@ def flash_attn_fn( dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, - return_attn_probs=needs_weights) + return_attn_probs=needs_weights, + ) elif is_flash_v2_installed(): alibi_kwargs = {} if check_alibi_support('flash'): alibi_kwargs = {'alibi_slopes': alibi_slopes} elif alibi_slopes is not None: raise ValueError( - 'alibi_slopes is only supported for flash-attn>=2.4.2') + 'alibi_slopes is only supported for flash-attn>=2.4.2', + ) output_unpad = flash_attn_interface.flash_attn_varlen_func( q=query_unpad, k=key_unpad, @@ -343,14 +373,19 @@ def flash_attn_fn( causal=reset_is_causal, return_attn_probs=needs_weights, window_size=(sliding_window_size, sliding_window_size), - **alibi_kwargs) + **alibi_kwargs, + ) else: raise RuntimeError( - 'flash-attn==1.0.9 or flash-attn==2.4.2 is required.') + 'flash-attn==1.0.9 or flash-attn==2.4.2 is required.', + ) output = bert_padding.pad_input( - rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, - seqlen) + rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices_q, + batch_size, + seqlen, + ) return output, None, past_key_value @@ -401,12 +436,12 @@ def __init__( if self.kv_n_heads > self.n_heads: raise ValueError( - 'The number of KV heads should be less than or equal to Q heads.' + 'The number of KV heads should be less than or equal to Q heads.', ) if self.n_heads % self.kv_n_heads != 0: raise ValueError( - 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.' + 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.', ) if qk_ln and qk_gn: raise ValueError('Only one of qk_ln and qk_gn can be set to True.') @@ -470,7 +505,7 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ - torch.Tensor, torch.Tensor]]]: + torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) if self.clip_qkv: @@ -510,10 +545,12 @@ def forward( value = value.view(bsz, seqlen, -1, self.head_dim) kv = torch.stack([key, value], dim=2) - query, kv = rotary_emb(query, - kv, - seqlen_offset=offset_info, - max_seqlen=seq_len) + query, kv = rotary_emb( + query, + kv, + seqlen_offset=offset_info, + max_seqlen=seq_len, + ) [key, value] = torch.unbind(kv, dim=2) value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) @@ -526,27 +563,33 @@ def forward( else: (cos, sin) = rotary_emb(x=value, seq_len=seq_len) if is_transformers_version_gte('4.38'): - query, key = apply_rotary_pos_emb(q=query, - k=key, - cos=cos, - sin=sin, - position_ids=None, - unsqueeze_dim=2) + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=None, + unsqueeze_dim=2, + ) elif is_transformers_version_gte('4.36'): - query, key = apply_rotary_pos_emb(q=query, - k=key, - cos=cos, - sin=sin, - position_ids=offset_info, - unsqueeze_dim=2) + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info, + unsqueeze_dim=2, + ) else: query = query.transpose(1, 2) key = key.transpose(1, 2) - query, key = apply_rotary_pos_emb(q=query, - k=key, - cos=cos, - sin=sin, - position_ids=offset_info) + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info, + ) query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -666,8 +709,13 @@ def __init__( def attn_bias_shape( - attn_impl: str, n_heads: int, seq_len: int, alibi: bool, causal: bool, - use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]: + attn_impl: str, + n_heads: int, + seq_len: int, + alibi: bool, + causal: bool, + use_sequence_id: bool, +) -> Optional[tuple[int, int, int, int]]: if attn_impl == 'flash': return None elif attn_impl == 'torch': @@ -705,16 +753,19 @@ def build_attn_bias( alibi_bias_max=alibi_bias_max, device=device, dtype=dtype, - )) + ), + ) return attn_bias else: raise ValueError(f'{attn_impl=} is an invalid setting.') -def gen_slopes(n_heads: int, - alibi_bias_max: int = 8, - device: Optional[torch.device] = None, - return_1d: bool = False) -> torch.Tensor: +def gen_slopes( + n_heads: int, + alibi_bias_max: int = 8, + device: Optional[torch.device] = None, + return_1d: bool = False, +) -> torch.Tensor: _n_heads = 2**math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) @@ -744,8 +795,11 @@ def build_alibi_bias( # generate 1 x Heads x SeqLen x SeqLen alibi bias mask # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size) alibi_bias = alibi_bias - torch.arange( - 1 - seq_len, 1, dtype=torch.int32, device=device).view( - 1, 1, seq_len, 1) + 1 - seq_len, + 1, + dtype=torch.int32, + device=device, + ).view(1, 1, seq_len, 1) alibi_bias = alibi_bias.abs().mul(-1) slopes = gen_slopes(n_heads, alibi_bias_max, device=device) @@ -754,5 +808,7 @@ def build_alibi_bias( attention_implementations.register('flash', func=flash_attn_fn) -attention_implementations.register('torch', - func=scaled_multihead_dot_product_attention) +attention_implementations.register( + 'torch', + func=scaled_multihead_dot_product_attention, +) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index d56c4753af..494bdcdff1 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -9,8 +9,11 @@ import torch.nn as nn from llmfoundry.layers_registry import ffns_with_norm -from llmfoundry.models.layers.layer_builders import (build_attention_layer, - build_ffn, build_norm) +from llmfoundry.models.layers.layer_builders import ( + build_attention_layer, + build_ffn, + build_norm, +) try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -97,9 +100,15 @@ def __init__( assert isinstance(attn_config['attn_type'], str) # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { - 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', - 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', - 'rope_hf_config' + 'attn_type', + 'alibi', + 'attn_uses_sequence_id', + 'alibi_bias_max', + 'rope', + 'rope_theta', + 'rope_impl', + 'rope_dail_config', + 'rope_hf_config', } attn_config_subset_for_attn_class = { k: v @@ -120,7 +129,7 @@ def __init__( 'fc_type': fc_type, 'device': device, 'bias': not no_bias, - **attn_config_subset_for_attn_class + **attn_config_subset_for_attn_class, }, ) self.norm_2 = None @@ -156,7 +165,7 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ - torch.Tensor, torch.Tensor]]]: + torch.Tensor, torch.Tensor]]]: if self.fuse_norm_attn_norm: x, m, attn_weights, past_key_value = self.norm_attn_norm( x, @@ -221,9 +230,15 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { - 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', - 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', - 'rope_hf_config' + 'attn_type', + 'alibi', + 'attn_uses_sequence_id', + 'alibi_bias_max', + 'rope', + 'rope_theta', + 'rope_impl', + 'rope_dail_config', + 'rope_hf_config', } attn_config_subset_for_attn_class = { k: v @@ -243,7 +258,7 @@ def __init__( 'fc_type': fc_type, 'device': device, 'bias': not no_bias, - **attn_config_subset_for_attn_class + **attn_config_subset_for_attn_class, }, ) diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py index f2b255294c..e467ce227f 100644 --- a/llmfoundry/models/layers/dmoe.py +++ b/llmfoundry/models/layers/dmoe.py @@ -20,9 +20,10 @@ class _UniformExpertAssignment(torch.autograd.Function): @staticmethod def forward( - ctx, # pyright: ignore[reportMissingParameterType] - x: torch.Tensor, - num_experts: int): + ctx, # pyright: ignore[reportMissingParameterType] + x: torch.Tensor, + num_experts: int, + ): out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) out = torch.remainder(out, num_experts) return out.view(x.shape) @@ -30,10 +31,16 @@ def forward( class LearnedRouter(torch.nn.Module): - def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, - moe_jitter_eps: float, moe_normalize_expert_weights: bool, - uniform_expert_assignment: bool, - device: Optional[torch.device]) -> None: + def __init__( + self, + hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + moe_jitter_eps: float, + moe_normalize_expert_weights: bool, + uniform_expert_assignment: bool, + device: Optional[torch.device], + ) -> None: super().__init__() self.hidden_size: int = hidden_size self.moe_num_experts: int = moe_num_experts @@ -52,17 +59,23 @@ def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, def jitter(self, x: torch.Tensor) -> torch.Tensor: low: float = 1.0 - self.moe_jitter_eps high: float = 1.0 + self.moe_jitter_eps - noise: torch.Tensor = torch.rand(x.size(), - dtype=x.dtype, - device=x.device) + noise: torch.Tensor = torch.rand( + x.size(), + dtype=x.dtype, + device=x.device, + ) return low + noise * (high - low) def _top_k(self, scores: torch.Tensor) -> torch.Tensor: if self.moe_top_k == 1: return scores.max( - dim=-1) # pyright: ignore[reportGeneralTypeIssues] - return torch.topk(scores, self.moe_top_k, - dim=-1) # pyright: ignore[reportGeneralTypeIssues] + dim=-1, + ) # pyright: ignore[reportGeneralTypeIssues] + return torch.topk( + scores, + self.moe_top_k, + dim=-1, + ) # pyright: ignore[reportGeneralTypeIssues] def forward(self, x: torch.Tensor): if self.training and self.moe_jitter_eps is not None: @@ -75,11 +88,13 @@ def forward(self, x: torch.Tensor): expert_weights, p=self.moe_normalize_expert_weights, dim=-1, - keepdim=True) + keepdim=True, + ) - top_experts = (_UniformExpertAssignment.apply(top_experts, - self.moe_num_experts) - if self.uniform_expert_assignment else top_experts) + top_experts = ( + _UniformExpertAssignment.apply(top_experts, self.moe_num_experts) + if self.uniform_expert_assignment else top_experts + ) scores = scores.to(x.dtype) expert_weights = expert_weights.to(x.dtype) return scores, expert_weights, top_experts @@ -103,20 +118,32 @@ def __init__( self.activation_fn: Callable = activation_fn self.w1 = torch.nn.Parameter( - torch.rand(moe_num_experts * ffn_hidden_size, - hidden_size, - device=device)) + torch.rand( + moe_num_experts * ffn_hidden_size, + hidden_size, + device=device, + ), + ) self.w2 = torch.nn.Parameter( - torch.rand(moe_num_experts * ffn_hidden_size, - hidden_size, - device=device)) + torch.rand( + moe_num_experts * ffn_hidden_size, + hidden_size, + device=device, + ), + ) self.activation_fn = activation_fn def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: - expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, - self.hidden_size)[expert_idx] - expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, - self.hidden_size)[expert_idx] + expert_w1 = self.w1.view( + self.moe_num_experts, + self.ffn_hidden_size, + self.hidden_size, + )[expert_idx] + expert_w2 = self.w2.view( + self.moe_num_experts, + self.ffn_hidden_size, + self.hidden_size, + )[expert_idx] before_activation = x @ expert_w1.t() layer_1_output = self.activation_fn(before_activation) @@ -140,26 +167,44 @@ def __init__( self.moe_num_experts = moe_num_experts self.w1 = torch.nn.Parameter( - torch.rand(moe_num_experts * ffn_hidden_size, - hidden_size, - device=device)) + torch.rand( + moe_num_experts * ffn_hidden_size, + hidden_size, + device=device, + ), + ) self.v1 = torch.nn.Parameter( - torch.rand(moe_num_experts * ffn_hidden_size, - hidden_size, - device=device)) + torch.rand( + moe_num_experts * ffn_hidden_size, + hidden_size, + device=device, + ), + ) self.w2 = torch.nn.Parameter( - torch.rand(moe_num_experts * ffn_hidden_size, - hidden_size, - device=device)) + torch.rand( + moe_num_experts * ffn_hidden_size, + hidden_size, + device=device, + ), + ) self.activation_fn = activation_fn def forward(self, x: torch.Tensor, expert_idx: torch.Tensor): - expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, - self.hidden_size)[expert_idx] - expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, - self.hidden_size)[expert_idx] - expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, - self.hidden_size)[expert_idx] + expert_w1 = self.w1.view( + self.moe_num_experts, + self.ffn_hidden_size, + self.hidden_size, + )[expert_idx] + expert_v1 = self.v1.view( + self.moe_num_experts, + self.ffn_hidden_size, + self.hidden_size, + )[expert_idx] + expert_w2 = self.w2.view( + self.moe_num_experts, + self.ffn_hidden_size, + self.hidden_size, + )[expert_idx] x1 = x.matmul(expert_w1.t()) x2 = x.matmul(expert_v1.t()) @@ -185,22 +230,31 @@ def __init__( self.moe_num_experts = moe_num_experts if mlp_type == 'mlp': - self.mlp = MLP(hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=moe_num_experts, - activation_fn=activation_fn, - device=device) + self.mlp = MLP( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + activation_fn=activation_fn, + device=device, + ) elif mlp_type == 'glu': - self.mlp = GLU(hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=moe_num_experts, - activation_fn=activation_fn, - device=device) + self.mlp = GLU( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + activation_fn=activation_fn, + device=device, + ) else: raise ValueError(f'Received unknown {mlp_type=}') - def forward(self, x: torch.Tensor, scores: torch.Tensor, - expert_weights: torch.Tensor, top_experts: torch.Tensor): + def forward( + self, + x: torch.Tensor, + scores: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + ): in_shape = x.shape hidden_size = in_shape[-1] @@ -208,7 +262,9 @@ def forward(self, x: torch.Tensor, scores: torch.Tensor, out = torch.zeros_like(x) expert_mask = torch.nn.functional.one_hot( - top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) + top_experts, + num_classes=self.moe_num_experts, + ).permute(2, 1, 0) for expert_idx in range(0, self.moe_num_experts): topk_idx, token_idx = torch.where(expert_mask[expert_idx]) if token_idx.shape[0] == 0: diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index c64e87cb9a..3542da7c4b 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -13,8 +13,11 @@ from torch.distributed import ProcessGroup from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard -from llmfoundry.layers_registry import (ffns, ffns_with_megablocks, - ffns_with_norm) +from llmfoundry.layers_registry import ( + ffns, + ffns_with_megablocks, + ffns_with_norm, +) from llmfoundry.models.layers.dmoe import dMoE from llmfoundry.models.layers.layer_builders import build_fc @@ -50,7 +53,8 @@ def resolve_ffn_act_fn( - config: Optional[dict] = None,) -> Callable[[torch.Tensor], torch.Tensor]: + config: Optional[dict] = None, +) -> Callable[[torch.Tensor], torch.Tensor]: """Resolve the activation function for the feed-forward network. Args: @@ -91,19 +95,22 @@ def resolve_ffn_hidden_size( """ if ffn_hidden_size is not None: log.info( - f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.' + f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.', ) else: ffn_hidden_size = int(d_model * expansion_ratio) if ffn_hidden_size != d_model * expansion_ratio: raise ValueError( - f'`d_model * expansion_ratio` must be an integer ({d_model=}; {expansion_ratio=}; {d_model * expansion_ratio=}).' + f'`d_model * expansion_ratio` must be an integer ({d_model=}; {expansion_ratio=}; {d_model * expansion_ratio=}).', ) return ffn_hidden_size -def dtensorify_param(param: nn.Parameter, mesh: DeviceMesh, - placements: List[Placement]): +def dtensorify_param( + param: nn.Parameter, + mesh: DeviceMesh, + placements: List[Placement], +): """Construct a DTensor from an already sharded local parameter.""" param_dtensor = DTensor.from_local( param.data, @@ -127,8 +134,11 @@ def __init__( bias: bool = True, ): super().__init__() - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) + ffn_hidden_size = resolve_ffn_hidden_size( + d_model, + expansion_ratio, + ffn_hidden_size, + ) self.fc_kwargs: dict[str, Any] = { 'bias': bias, } @@ -237,11 +247,14 @@ def build_te_ln_mlp( **kwargs: Any, ) -> nn.Module: assert te is not None - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) + ffn_hidden_size = resolve_ffn_hidden_size( + d_model, + expansion_ratio, + ffn_hidden_size, + ) if ffn_act_fn is not None: raise ValueError( - f'Transformer Engine block does not support custom activation functions.' + f'Transformer Engine block does not support custom activation functions.', ) return te.LayerNormMLP( hidden_size=d_model, @@ -275,8 +288,11 @@ def build_torch_dmoe( return dMoE( hidden_size=d_model, - ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size), + ffn_hidden_size=resolve_ffn_hidden_size( + d_model, + expansion_ratio, + ffn_hidden_size, + ), moe_num_experts=moe_num_experts, moe_top_k=moe_top_k, mlp_type=mlp_type, @@ -300,15 +316,18 @@ def _mb_setup_args( ) -> tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]: if megablocks is None: raise RuntimeError( - 'Requirements for megablocks not installed; see install instructions in `README.md`.' + 'Requirements for megablocks not installed; see install instructions in `README.md`.', ) args = kwargs['args'] args.bias = bias args.hidden_size = d_model args.device = device - ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, - ffn_hidden_size) + ffn_hidden_size = resolve_ffn_hidden_size( + d_model, + expansion_ratio, + ffn_hidden_size, + ) args.ffn_hidden_size = ffn_hidden_size if ffn_act_fn is not None: @@ -320,7 +339,8 @@ def _mb_setup_args( moe_world_size = expert_parallel_group.size() if kwargs.get('moe_world_size') != moe_world_size: raise RuntimeError( - f'MoE expert_parallel_group configured with incorrect world size.') + f'MoE expert_parallel_group configured with incorrect world size.', + ) return args, moe_world_size, expert_parallel_group @@ -341,13 +361,14 @@ def _patch_ffn_mb( expert_mesh = device_mesh['expert_parallel'] expert_placements: List[Placement] = [Shard(0)] # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() - dtensorified_params = [ - (name, - dtensorify_param(param=parameter, - mesh=expert_mesh, - placements=expert_placements)) - for name, parameter in ffn.experts.mlp.named_parameters() - ] + dtensorified_params = [( + name, + dtensorify_param( + param=parameter, + mesh=expert_mesh, + placements=expert_placements, + ), + ) for name, parameter in ffn.experts.mlp.named_parameters()] for name, dtensorified_param in dtensorified_params: ffn.experts.mlp.register_parameter(name, dtensorified_param) @@ -374,7 +395,7 @@ def build_mb_moe( ) -> nn.Module: if not is_megablocks_imported: raise RuntimeError( - 'Requirements for megablocks not installed; see install instructions in `README.md`.' + 'Requirements for megablocks not installed; see install instructions in `README.md`.', ) args, moe_world_size, expert_parallel_group = _mb_setup_args( @@ -415,7 +436,7 @@ def build_mb_dmoe( ) -> nn.Module: if not is_megablocks_imported: raise RuntimeError( - 'Requirements for megablocks not installed; see install instructions in `README.md`.' + 'Requirements for megablocks not installed; see install instructions in `README.md`.', ) args, moe_world_size, expert_parallel_group = _mb_setup_args( @@ -433,9 +454,10 @@ def build_mb_dmoe( # Fused initialization setup # For param_init_fn, enables shape based init of fused layers n_exp = min(1, args.moe_num_experts // moe_world_size) - ffn.experts.mlp._fused = (0, [ - (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1) - ]) + ffn.experts.mlp._fused = ( + 0, + [(n + 1) * args.ffn_hidden_size for n in range(n_exp - 1)], + ) _patch_ffn_mb( ffn=ffn, diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index ceb41d8d41..f6c8ce3533 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -5,9 +5,14 @@ import torch -from llmfoundry.layers_registry import (attention_classes, fcs, ffns, - ffns_with_megablocks, ffns_with_norm, - norms) +from llmfoundry.layers_registry import ( + attention_classes, + fcs, + ffns, + ffns_with_megablocks, + ffns_with_norm, + norms, +) from llmfoundry.utils.registry_utils import construct_from_registry __all__ = [ @@ -28,10 +33,12 @@ def build_norm( 'device': device, } - return construct_from_registry(name=name, - registry=norms, - pre_validation_function=torch.nn.Module, - kwargs=kwargs) + return construct_from_registry( + name=name, + registry=norms, + pre_validation_function=torch.nn.Module, + kwargs=kwargs, + ) def build_ffn( @@ -67,7 +74,8 @@ def _validation_function(maybe_module: Any): registry=registry_to_use, post_validation_function=_validation_function, partial_function=False, - kwargs=kwargs) + kwargs=kwargs, + ) if name in ffns_with_norm: result._has_norm = True @@ -82,10 +90,12 @@ def build_attention_layer( name: str, attn_kwargs: Dict[str, Any], ): - return construct_from_registry(name=name, - registry=attention_classes, - pre_validation_function=torch.nn.Module, - kwargs=attn_kwargs) + return construct_from_registry( + name=name, + registry=attention_classes, + pre_validation_function=torch.nn.Module, + kwargs=attn_kwargs, + ) def build_fc( @@ -100,7 +110,9 @@ def build_fc( **fc_kwargs, } - return construct_from_registry(name=name, - registry=fcs, - pre_validation_function=torch.nn.Module, - kwargs=kwargs) + return construct_from_registry( + name=name, + registry=fcs, + pre_validation_function=torch.nn.Module, + kwargs=kwargs, + ) diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py index 23b92015e7..c853f5fd26 100644 --- a/llmfoundry/models/layers/norm.py +++ b/llmfoundry/models/layers/norm.py @@ -53,9 +53,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled( - self.weight) if self.weight is not None else self.weight + self.weight, + ) if self.weight is not None else self.weight downcast_bias = _cast_if_autocast_enabled( - self.bias) if self.bias is not None else self.bias + self.bias, + ) if self.bias is not None else self.bias with torch.autocast(enabled=False, device_type=module_device.type): return torch.nn.functional.layer_norm( downcast_x, @@ -66,9 +68,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) -def rms_norm(x: torch.Tensor, - weight: Optional[torch.Tensor] = None, - eps: float = 1e-5) -> torch.Tensor: +def rms_norm( + x: torch.Tensor, + weight: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> torch.Tensor: output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) if weight is not None: return output * weight @@ -90,7 +94,8 @@ def __init__( self.eps = eps if weight: self.weight = torch.nn.Parameter( - torch.ones(normalized_shape, dtype=dtype, device=device)) + torch.ones(normalized_shape, dtype=dtype, device=device), + ) else: self.register_parameter('weight', None) @@ -120,7 +125,8 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled( - self.weight) if self.weight is not None else self.weight + self.weight, + ) if self.weight is not None else self.weight with torch.autocast(enabled=False, device_type=x.device.type): return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) @@ -144,7 +150,8 @@ def __init__( except ImportError: raise ImportError( 'triton_rms_norm requires Flash Attention to be installed. ' + - 'Please pip install flash-attn.') + 'Please pip install flash-attn.', + ) if not isinstance(normalized_shape, int): raise ValueError('TritonRMSNorm only supports 1D tensors') @@ -152,7 +159,8 @@ def __init__( self.rms_norm_fn = rms_norm_fn self.weight = torch.nn.Parameter( - torch.ones(normalized_shape, device=device, dtype=dtype)) + torch.ones(normalized_shape, device=device, dtype=dtype), + ) def forward(self, x: torch.Tensor): # Flash Attention expect a flat tensor diff --git a/llmfoundry/models/mpt/__init__.py b/llmfoundry/models/mpt/__init__.py index 04bed25fab..688fb1f558 100644 --- a/llmfoundry/models/mpt/__init__.py +++ b/llmfoundry/models/mpt/__init__.py @@ -2,9 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.mpt.configuration_mpt import MPTConfig -from llmfoundry.models.mpt.modeling_mpt import (ComposerMPTCausalLM, - MPTForCausalLM, MPTModel, - MPTPreTrainedModel) +from llmfoundry.models.mpt.modeling_mpt import ( + ComposerMPTCausalLM, + MPTForCausalLM, + MPTModel, + MPTPreTrainedModel, +) __all__ = [ 'MPTPreTrainedModel', diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index dbee232f3d..78653fabdc 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -9,8 +9,10 @@ from transformers import PretrainedConfig from llmfoundry.layers_registry import ffns_with_megablocks -from llmfoundry.models.layers.attention import (check_alibi_support, - is_flash_v2_installed) +from llmfoundry.models.layers.attention import ( + check_alibi_support, + is_flash_v2_installed, +) from llmfoundry.models.layers.blocks import attn_config_defaults # NOTE: All utils are imported directly even if unused so that @@ -159,11 +161,11 @@ def __init__( del kwargs['name'] if 'loss_fn' in kwargs: del kwargs['loss_fn'] - if self.attn_config.get('alibi', False) or self.attn_config.get( - 'rope', False): + if self.attn_config.get('alibi', + False) or self.attn_config.get('rope', False): self.learned_pos_emb = False warnings.warn( - f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`' + f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`', ) # tie_word_embeddings is set in Huggingface's PretrainedConfig __init__ super().__init__( @@ -173,8 +175,11 @@ def __init__( self._validate_config() - def _set_config_defaults(self, config: Dict[str, Any], - config_defaults: Dict[str, Any]) -> Dict[str, Any]: + def _set_config_defaults( + self, + config: Dict[str, Any], + config_defaults: Dict[str, Any], + ) -> Dict[str, Any]: # set config defaults for k, v in config_defaults.items(): if k not in config: @@ -182,7 +187,9 @@ def _set_config_defaults(self, config: Dict[str, Any], elif isinstance(v, dict): # recursively set default values for any sub-dicts config[k] = self._set_config_defaults( - config[k] if (config[k] is not None) else {}, v) + config[k] if (config[k] is not None) else {}, + v, + ) return config def _validate_config(self) -> None: @@ -203,72 +210,87 @@ def _validate_config(self) -> None: if self.d_model % self.n_heads != 0: raise ValueError('d_model must be divisible by n_heads') if any( - prob < 0 or prob > 1 for prob in - [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]): + prob < 0 or prob > 1 for prob in + [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop] + ): raise ValueError( - "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" + "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1", ) if self.attn_config['attn_impl'] not in ['torch', 'flash']: raise ValueError( - f"Unknown attn_impl={self.attn_config['attn_impl']}") + f"Unknown attn_impl={self.attn_config['attn_impl']}", + ) if self.attn_config['alibi'] and not check_alibi_support( - self.attn_config['attn_impl']): + self.attn_config['attn_impl'], + ): raise NotImplementedError( - 'alibi only implemented with torch and flash (v2.4.2 or higher) attention.' + 'alibi only implemented with torch and flash (v2.4.2 or higher) attention.', ) if self.attn_config['attn_uses_sequence_id'] and not ( - self.attn_config['attn_impl'] == 'torch' or - (self.attn_config['attn_impl'] == 'flash' and - is_flash_v2_installed(v2_version='v2.1.2'))): + self.attn_config['attn_impl'] == 'torch' or ( + self.attn_config['attn_impl'] == 'flash' and + is_flash_v2_installed(v2_version='v2.1.2') + ) + ): raise NotImplementedError( - 'attn_uses_sequence_id only implemented with torch and flash (v2.1.2 or higher) attention.' + 'attn_uses_sequence_id only implemented with torch and flash (v2.1.2 or higher) attention.', ) - if self.attn_config['rope'] and (self.attn_config['rope_impl'] - not in ['dail', 'hf']): + if self.attn_config['rope'] and ( + self.attn_config['rope_impl'] not in ['dail', 'hf'] + ): raise ValueError( - 'If rope is being used then rope_impl should be either "dail", or "hf".' + 'If rope is being used then rope_impl should be either "dail", or "hf".', ) if self.attn_config['rope'] and ( - self.attn_config['rope_impl'] - == 'hf') and self.attn_config['rope_hf_config']['type'] not in [ - 'no_scaling', 'linear', 'dynamic' - ]: + self.attn_config['rope_impl'] == 'hf' + ) and self.attn_config['rope_hf_config']['type'] not in [ + 'no_scaling', + 'linear', + 'dynamic', + ]: raise ValueError( - 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' + 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".', ) - if self.attn_config['rope'] and (self.attn_config['rope_impl'] - == 'dail'): + if self.attn_config['rope'] and ( + self.attn_config['rope_impl'] == 'dail' + ): if self.attn_config['rope_dail_config']['type'] not in [ - 'original', 'xpos' + 'original', + 'xpos', ]: raise ValueError( - 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' + 'If using the dail implementation of rope, the type should be one of "original" or "xpos".', ) if not is_flash_v2_installed(v2_version='2.0.1'): raise ImportError( - 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support' + 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support', ) if self.attn_config['sliding_window_size'] != -1 and not ( - self.attn_config['attn_impl'] == 'flash' and - is_flash_v2_installed(v2_version='v2.3.0')): + self.attn_config['attn_impl'] == 'flash' and + is_flash_v2_installed(v2_version='v2.3.0') + ): raise NotImplementedError( - 'sliding window only implemented with flash attention v2.3.0 or higher.' + 'sliding window only implemented with flash attention v2.3.0 or higher.', ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( - 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' + 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!', ) - if isinstance(self.logit_scale, - str) and self.logit_scale != 'inv_sqrt_d_model': + if isinstance( + self.logit_scale, + str, + ) and self.logit_scale != 'inv_sqrt_d_model': raise ValueError( - f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." + f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.", ) if self.init_config.get('name', None) is None: raise ValueError(f"{self.init_config=} 'name' needs to be set.") - if not (self.learned_pos_emb or self.attn_config['alibi'] or - self.attn_config['rope']): + if not ( + self.learned_pos_emb or self.attn_config['alibi'] or + self.attn_config['rope'] + ): warnings.warn( - f'Positional information not being provided to the model using either learned_pos_emb or alibi or rope.' + f'Positional information not being provided to the model using either learned_pos_emb or alibi or rope.', ) if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': try: @@ -280,13 +302,13 @@ def _validate_config(self) -> None: + 'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n' + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + - 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156', ) if self.ffn_config['ffn_type'] == 'mptgeglu': raise ValueError( 'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. ' + - 'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.' + 'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.', ) elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']: self.ffn_config['fc_type'] = self.fc_type @@ -296,12 +318,12 @@ def _validate_config(self) -> None: self.ffn_config['bias'] = not self.no_bias if 'ffn_act_fn' in self.ffn_config.keys(): raise ValueError( - f'Transformer Engine block does not support custom activation functions.' + f'Transformer Engine block does not support custom activation functions.', ) if not self.use_pad_tok_in_ffn: try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip except: raise ImportError( - 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6' + 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6', ) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 1ef62a3b19..bdf6cff925 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -11,8 +11,16 @@ import math import warnings from functools import cached_property -from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, - Union) +from typing import ( + Any, + Dict, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) import torch import torch.nn as nn @@ -34,8 +42,10 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) from transformers.models.llama.modeling_llama import \ LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding from transformers.models.llama.modeling_llama import \ @@ -44,15 +54,20 @@ LlamaRotaryEmbedding as HFRotaryEmbedding from llmfoundry.layers_registry import norms, param_init_fns -from llmfoundry.models.layers.attention import (attn_bias_shape, - build_attn_bias, gen_slopes) +from llmfoundry.models.layers.attention import ( + attn_bias_shape, + build_attn_bias, + gen_slopes, +) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig from llmfoundry.models.utils.config_moe_args import config_moe_args -from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params, - mpt_get_total_params) +from llmfoundry.models.utils.mpt_param_count import ( + mpt_get_active_params, + mpt_get_total_params, +) # NOTE: All utils are imported directly even if unused so that # HuggingFace can detect all the needed files to copy into its modules folder. @@ -65,18 +80,25 @@ ) from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note) -from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx, - build_act_ckpt_mod_to_blocks, - check_mapping_blocks_overlap) +from llmfoundry.models.utils.act_ckpt import ( + pass_on_block_idx, + build_act_ckpt_mod_to_blocks, + check_mapping_blocks_overlap, +) import logging log = logging.getLogger(__name__) -def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, - rope_dail_config: dict, rope_hf_config: dict, - max_seq_len: int): +def gen_rotary_embedding( + rope_head_dim: int, + rope_impl: str, + rope_theta: int, + rope_dail_config: dict, + rope_hf_config: dict, + max_seq_len: int, +): if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, @@ -95,7 +117,7 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, max_position_embeddings=max_seq_len, base=rope_theta, device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_hf_config['type'] == 'linear': return HFLinearScalingRotaryEmbedding( @@ -104,7 +126,7 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, base=rope_theta, scaling_factor=rope_hf_config['factor'], device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_hf_config['type'] == 'dynamic': return HFDynamicNTKScalingRotaryEmbedding( @@ -113,14 +135,18 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, base=rope_theta, scaling_factor=rope_hf_config['factor'], device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) raise ValueError('rope_impl needs to be either dail or hf') -def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, - attn_uses_sequence_id: bool, attn_impl: str, - attention_mask: Union[torch.Tensor, None]): +def gen_attention_mask_in_length( + sequence_id: Union[None, torch.Tensor], + S: int, + attn_uses_sequence_id: bool, + attn_impl: str, + attention_mask: Union[torch.Tensor, None], +): """Generates the attention mask used for sequence masking in FA v2. Only supports sequence id based sparse attention for no attention masking or attention masking with right padding. @@ -176,17 +202,17 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, (The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .) """ attention_mask_in_length = None - if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl - == 'flash'): + if (sequence_id + is not None) and attn_uses_sequence_id and (attn_impl == 'flash'): # Check if sequence has left padding. If yes, raise an error. - if (attention_mask is not None) and (attention_mask[:, 0].sum() != - attention_mask.shape[0]): + if (attention_mask is not None + ) and (attention_mask[:, 0].sum() != attention_mask.shape[0]): raise NotImplementedError( - 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.' + 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.', ) if S != sequence_id.shape[-1]: raise ValueError( - f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).' + f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).', ) if attention_mask is not None: # -1 is used to pad the sequence_id where attention mask is False (https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249). @@ -196,24 +222,28 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, attention_mask_in_length = torch.nn.functional.one_hot(sequence_id) if attention_mask is not None: attention_mask_in_length = attention_mask_in_length.masked_fill( - ~attention_mask.unsqueeze(-1), 0) + ~attention_mask.unsqueeze(-1), + 0, + ) attention_mask_in_length = attention_mask_in_length.sum(dim=1) attention_mask_in_length = torch.nn.functional.pad( attention_mask_in_length, (0, S - attention_mask_in_length.shape[-1]), mode='constant', - value=0) + value=0, + ) return attention_mask_in_length def gen_flash_attn_padding_info( - bsz: int, - S: int, - past_key_len: int, - device: torch.device, - attention_mask_in_length: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None): + bsz: int, + S: int, + past_key_len: int, + device: torch.device, + attention_mask_in_length: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): flash_attn_padding_info = {} if attention_mask_in_length is None: key_padding_mask = attention_mask @@ -229,11 +259,17 @@ def gen_flash_attn_padding_info( unpadding_function = bert_padding.unpad_input_for_concatenated_sequences _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( - torch.empty(bsz, S, 1, device=device), query_padding_mask) + torch.empty(bsz, S, 1, device=device), + query_padding_mask, + ) _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function( - torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + torch.empty(bsz, past_key_len + S, 1, device=device), + key_padding_mask, + ) _, indices_v, _, _ = unpadding_function( - torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + torch.empty(bsz, past_key_len + S, 1, device=device), + key_padding_mask, + ) flash_attn_padding_info['indices_q'] = indices_q flash_attn_padding_info['indices_k'] = indices_k @@ -245,12 +281,15 @@ def gen_flash_attn_padding_info( return flash_attn_padding_info -def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, - max_seq_len: int) -> torch.Tensor: +def apply_sequence_id( + attn_bias: torch.Tensor, + sequence_id: torch.LongTensor, + max_seq_len: int, +) -> torch.Tensor: seq_len = sequence_id.shape[-1] if seq_len > max_seq_len: raise ValueError( - f'sequence_id sequence length cannot exceed max_seq_len={max_seq_len}' + f'sequence_id sequence length cannot exceed max_seq_len={max_seq_len}', ) # select seq_len subset of attn mask @@ -262,7 +301,8 @@ def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, torch.eq( sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len), - )).unsqueeze(1) + ), + ).unsqueeze(1) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) @@ -307,20 +347,24 @@ def __init__(self, config: MPTConfig): if config.norm_type.lower() not in norms.get_all(): norm_options = ' | '.join(norms.get_all()) raise NotImplementedError( - f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).' + f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).', ) # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414) # both report this helping with stabilizing training self.embedding_fraction = config.embedding_fraction - self.wte = SharedEmbedding(config.vocab_size, - config.d_model, - device=config.init_device) + self.wte = SharedEmbedding( + config.vocab_size, + config.d_model, + device=config.init_device, + ) if self.learned_pos_emb: - self.wpe = torch.nn.Embedding(config.max_seq_len, - config.d_model, - device=config.init_device) + self.wpe = torch.nn.Embedding( + config.max_seq_len, + config.d_model, + device=config.init_device, + ) self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() @@ -362,11 +406,12 @@ def __init__(self, config: MPTConfig): rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], - max_seq_len=self.config.max_seq_len) + max_seq_len=self.config.max_seq_len, + ) if config.init_device != 'meta': log.info( - f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.' + f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.', ) self.apply(self.param_init_fn) @@ -386,8 +431,8 @@ def __init__(self, config: MPTConfig): if config.no_bias: for module in self.modules(): - if hasattr(module, 'bias') and isinstance( - module.bias, nn.Parameter): + if hasattr(module, + 'bias') and isinstance(module.bias, nn.Parameter): log.debug(f'Removing bias from {module=}.') module.register_parameter('bias', None) @@ -403,7 +448,9 @@ def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: return self.wte def set_input_embeddings( - self, value: Union[SharedEmbedding, nn.Embedding]) -> None: + self, + value: Union[SharedEmbedding, nn.Embedding], + ) -> None: self.wte = value @torch.no_grad() @@ -416,9 +463,11 @@ def _attn_bias( ) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]: if not self._attn_bias_initialized: if self.attn_bias_shape: - self.attn_bias = torch.zeros(self.attn_bias_shape, - device=device, - dtype=dtype) + self.attn_bias = torch.zeros( + self.attn_bias_shape, + device=device, + dtype=dtype, + ) self.attn_bias = build_attn_bias( self.attn_impl, self.attn_bias, @@ -444,8 +493,11 @@ def _attn_bias( # If using torch, we incorporate sequence_id (if appropriate) if self.attn_uses_sequence_id and sequence_id is not None: assert isinstance(attn_bias, torch.Tensor) # pyright - attn_bias = apply_sequence_id(attn_bias, sequence_id, - self.config.max_seq_len) + attn_bias = apply_sequence_id( + attn_bias, + sequence_id, + self.config.max_seq_len, + ) # If using torch, we incorporate attention_mask. This will output # None in place of attention_mask since it will not be further needed in the @@ -462,7 +514,9 @@ def _attn_bias( attn_bias = attn_bias[:, :, :, _s_k:] min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill( - ~attention_mask.view(-1, 1, 1, s_k), min_val) + ~attention_mask.view(-1, 1, 1, s_k), + min_val, + ) return attn_bias, attention_mask @@ -478,10 +532,12 @@ def forward( use_cache: Optional[bool] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> BaseModelOutputWithPast: - return_dict = (return_dict - if return_dict is not None else self.config.return_dict) - use_cache = (use_cache - if use_cache is not None else self.config.use_cache) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) if attention_mask is not None: attention_mask = attention_mask.bool() # type: ignore @@ -491,34 +547,40 @@ def forward( # but have not yet been fully implemented in MPTModel if not return_dict: raise NotImplementedError( - 'return_dict False is not implemented yet for MPT') + 'return_dict False is not implemented yet for MPT', + ) if output_attentions: if self.attn_impl != 'torch': raise NotImplementedError( - 'output_attentions is not implemented for MPT when using attn_impl `flash`.' + 'output_attentions is not implemented for MPT when using attn_impl `flash`.', ) - if (self.training and attention_mask is not None and - attention_mask[:, 0].sum() != attention_mask.shape[0]): + if ( + self.training and attention_mask is not None and + attention_mask[:, 0].sum() != attention_mask.shape[0] + ): raise NotImplementedError( - 'MPT does not support training with left padding.') + 'MPT does not support training with left padding.', + ) if self.training: if self.attn_uses_sequence_id and sequence_id is None: raise ValueError( 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' - + 'and the model is in train mode.') - elif (self.attn_uses_sequence_id is False) and (sequence_id - is not None): + + 'and the model is in train mode.', + ) + elif (self.attn_uses_sequence_id is + False) and (sequence_id is not None): warnings.warn( 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + - 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.', ) if input_ids is not None and inputs_embeds is not None: raise ValueError( - 'You cannot specify both input_ids and inputs_embeds.') + 'You cannot specify both input_ids and inputs_embeds.', + ) elif input_ids is not None: bsz = input_ids.size(0) S = input_ids.size(1) @@ -544,7 +606,7 @@ def forward( raise ValueError( f'past_key_values must provide a past_key_value for each attention ' + - f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).', ) # For attn_impl: flash, the past key tensor spec is (batch, seq, dim). # For attn_impl: torch, the past key tensor spec is (batch, heads, head_dim, seq). @@ -554,12 +616,13 @@ def forward( past_position = past_key_values[0][0].size(3) if self.learned_pos_emb or self.rope: - if self.learned_pos_emb and (S + past_position > - self.config.max_seq_len): + if self.learned_pos_emb and ( + S + past_position > self.config.max_seq_len + ): raise ValueError( f'Cannot forward input with past sequence length {past_position} and current sequence length ' + - f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.', ) if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'): @@ -597,8 +660,8 @@ def forward( x = self.emb_drop(x) else: # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414 - x_shrunk = (x * self.embedding_fraction) + ( - x.detach() * (1 - self.embedding_fraction)) + x_shrunk = (x * self.embedding_fraction + ) + (x.detach() * (1 - self.embedding_fraction)) assert isinstance(self.emb_drop, nn.Module) # pyright x = self.emb_drop(x_shrunk) @@ -613,14 +676,17 @@ def forward( S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, - attention_mask=attention_mask) + attention_mask=attention_mask, + ) alibi_slopes = None # alibi_slopes will only be used by flash attention for ALiBi if self.alibi and self.attn_impl == 'flash': - alibi_slopes = gen_slopes(n_heads=self.config.n_heads, - alibi_bias_max=self.alibi_bias_max, - device=x.device, - return_1d=True) + alibi_slopes = gen_slopes( + n_heads=self.config.n_heads, + alibi_bias_max=self.alibi_bias_max, + device=x.device, + return_1d=True, + ) # initialize the past key values cache if it should be used presents = () if use_cache else None @@ -633,15 +699,21 @@ def forward( flash_attn_padding_info = {} if self.attn_impl == 'flash': flash_attn_padding_info = gen_flash_attn_padding_info( - bsz, S, past_position, x.device, attention_mask_in_length, - attention_mask) + bsz, + S, + past_position, + x.device, + attention_mask_in_length, + attention_mask, + ) for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) - past_key_value = (past_key_values[b_idx] - if past_key_values is not None else None) + past_key_value = ( + past_key_values[b_idx] if past_key_values is not None else None + ) x, attn_weights, present = block( x, past_key_value=past_key_value, @@ -727,7 +799,7 @@ def __init__(self, config: MPTConfig): logit_scale = 1 / math.sqrt(config.d_model) else: raise ValueError( - f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." + f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.", ) self.logit_scale = logit_scale @@ -735,30 +807,36 @@ def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: return self.transformer.get_input_embeddings() def set_input_embeddings( - self, value: Union[SharedEmbedding, nn.Embedding]) -> None: + self, + value: Union[SharedEmbedding, nn.Embedding], + ) -> None: self.transformer.set_input_embeddings(value) def get_output_embeddings( - self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]: + self, + ) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]: if self.lm_head is not None: return self.lm_head return self.transformer.get_input_embeddings() def set_output_embeddings( - self, new_embeddings: Union[SharedEmbedding, nn.Embedding, - nn.Linear]) -> None: + self, + new_embeddings: Union[SharedEmbedding, nn.Embedding, nn.Linear], + ) -> None: if self.lm_head is not None: self.lm_head = new_embeddings else: if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)): raise ValueError( 'new_embeddings must be an instance of SharedEmbedding ' + - f'or nn.Embedding, but got {type(new_embeddings)}.') + f'or nn.Embedding, but got {type(new_embeddings)}.', + ) warnings.warn( 'Using `set_output_embeddings` to set the embedding layer of ' + 'MPTForCausalLM with tied weights. Given weights are tied, ' + 'using `set_input_embeddings` is recommended over using ' + - '`set_output_embeddings`.') + '`set_output_embeddings`.', + ) self.transformer.set_input_embeddings(new_embeddings) def tie_weights(self) -> None: @@ -784,10 +862,12 @@ def forward( use_cache: Optional[bool] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> CausalLMOutputWithPast: - return_dict = (return_dict - if return_dict is not None else self.config.return_dict) - use_cache = (use_cache - if use_cache is not None else self.config.use_cache) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) outputs = self.transformer( input_ids=input_ids, @@ -813,7 +893,7 @@ def forward( if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn( - f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.' + f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.', ) logits *= self.logit_scale @@ -883,17 +963,25 @@ def activation_checkpointing_fn(self, module: nn.Module) -> bool: """ if not hasattr(module, 'block_idx'): log.debug( - f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer block or its submodules are eligible for activation checkpointing.' + f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer block or its submodules are eligible for activation checkpointing.', ) return False - act_ckpt_target = getattr(self.config, - 'activation_checkpointing_target', None) + act_ckpt_target = getattr( + self.config, + 'activation_checkpointing_target', + None, + ) act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks( - act_ckpt_target, MPTBlock, module.max_block_idx) + act_ckpt_target, + MPTBlock, + module.max_block_idx, + ) - check_mapping_blocks_overlap(act_ckpt_mod_to_blocks, - module.max_block_idx) + check_mapping_blocks_overlap( + act_ckpt_mod_to_blocks, + module.max_block_idx, + ) for k in act_ckpt_mod_to_blocks.keys(): if isinstance(module, k): @@ -913,7 +1001,8 @@ def prepare_inputs_for_generation( attention_mask = kwargs['attention_mask'].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError( - 'MPT does not support generation with right padding.') + 'MPT does not support generation with right padding.', + ) if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) @@ -940,8 +1029,9 @@ def prepare_inputs_for_generation( @staticmethod def _reorder_cache( - past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], - beam_idx: torch.LongTensor) -> List[Tuple[torch.Tensor, ...]]: + past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], + beam_idx: torch.LongTensor, + ) -> List[Tuple[torch.Tensor, ...]]: """Used by HuggingFace generate when using beam search with kv-caching. See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 @@ -952,7 +1042,8 @@ def _reorder_cache( reordered_past += [ tuple( past_state.index_select(0, beam_idx) - for past_state in layer_past) + for past_state in layer_past + ), ] return reordered_past @@ -964,12 +1055,16 @@ def __init__( om_model_config: DictConfig, tokenizer: Optional[PreTrainedTokenizerBase] = None, ): - from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS, - DEFAULT_CAUSAL_LM_TRAIN_METRICS) + from llmfoundry.metrics import ( + DEFAULT_CAUSAL_LM_EVAL_METRICS, + DEFAULT_CAUSAL_LM_TRAIN_METRICS, + ) from llmfoundry.utils.builders import build_metric - resolved_om_model_config = om.to_container(om_model_config, - resolve=True) + resolved_om_model_config = om.to_container( + om_model_config, + resolve=True, + ) assert isinstance(resolved_om_model_config, dict) hf_config = MPTConfig.from_dict(resolved_om_model_config) @@ -977,12 +1072,16 @@ def __init__( use_train_metrics = om_model_config.get('use_train_metrics', True) train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + resolved_om_model_config.get( - 'additional_train_metrics', []) + 'additional_train_metrics', + [], + ) train_metrics = [ build_metric(metric, {}) for metric in train_metric_names ] if use_train_metrics else [] eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + resolved_om_model_config.get( - 'additional_eval_metrics', []) + 'additional_eval_metrics', + [], + ) eval_metrics = [ build_metric(metric, {}) for metric in eval_metric_names ] @@ -1003,22 +1102,26 @@ def __init__( from flash_attn.losses.cross_entropy import \ CrossEntropyLoss as FusedCrossEntropyLoss - self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, - reduction='none') + self.loss_fn = FusedCrossEntropyLoss( + ignore_index=-100, + reduction='none', + ) except: raise ValueError( 'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU ' + 'and `pip install .[gpu]` if installing from source or `pip install xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy` ' + - 'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.' + 'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.', ) elif loss_fn_config == 'torch_crossentropy': - self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100, - reduction='none') + self.loss_fn = nn.CrossEntropyLoss( + ignore_index=-100, + reduction='none', + ) else: raise ValueError( - f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].' + f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].', ) def get_targets(self, batch: Mapping) -> torch.Tensor: @@ -1033,7 +1136,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: from megablocks.layers.moe import clear_load_balancing_loss except: raise RuntimeError( - 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.', ) clear_load_balancing_loss() return self.model( @@ -1046,8 +1149,10 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: targets = self.get_targets(batch) - losses = self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), - targets.view(-1)) + losses = self.loss_fn( + outputs.logits.view(-1, outputs.logits.size(-1)), + targets.view(-1), + ) if torch.all(targets == self.loss_fn.ignore_index): loss = losses.sum() @@ -1060,7 +1165,7 @@ def loss(self, outputs: CausalLMOutputWithPast, from megablocks.layers.moe import batched_load_balancing_loss except: raise RuntimeError( - 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.', ) lbl = batched_load_balancing_loss(self.model.transformer.mb_args) return { @@ -1090,7 +1195,9 @@ def flops_per_batch(self, batch: Mapping): params = self.n_active_params params_flops_per_token = 2 * params params_flops_per_seq = params_flops_per_token * msl - attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 * - (self.model.config.d_model * (msl**2))) + attn_flops_per_seq = ( + self.model.config.n_layers * 2 * 2 * + (self.model.config.d_model * (msl**2)) + ) return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py index 45a5f757f6..8af39b4338 100644 --- a/llmfoundry/models/utils/__init__.py +++ b/llmfoundry/models/utils/__init__.py @@ -1,14 +1,20 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from llmfoundry.models.utils.act_ckpt import (build_act_ckpt_mod_to_blocks, - check_mapping_blocks_overlap, - pass_on_block_idx) +from llmfoundry.models.utils.act_ckpt import ( + build_act_ckpt_mod_to_blocks, + check_mapping_blocks_overlap, + pass_on_block_idx, +) from llmfoundry.models.utils.config_moe_args import config_moe_args -from llmfoundry.models.utils.meta_init_context import (init_empty_weights, - init_on_device) -from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params, - mpt_get_total_params) +from llmfoundry.models.utils.meta_init_context import ( + init_empty_weights, + init_on_device, +) +from llmfoundry.models.utils.mpt_param_count import ( + mpt_get_active_params, + mpt_get_total_params, +) from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ __all__ = [ diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index ef9a851a09..957b243baf 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,9 +5,13 @@ import torch -from llmfoundry.layers_registry import (attention_classes, ffns, - ffns_with_megablocks, ffns_with_norm, - norms) +from llmfoundry.layers_registry import ( + attention_classes, + ffns, + ffns_with_megablocks, + ffns_with_norm, + norms, +) from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock __all__ = [ @@ -48,9 +52,10 @@ def get_act_ckpt_module(mod_name: str) -> Any: list(attention_classes.keys()) + list(ffns.get_all()) + list(ffns_with_norm.get_all()) + list(ffns_with_megablocks.get_all()) + list(norms.get_all()) + - ['MPTBlock']) + ['MPTBlock'], + ) raise ValueError( - f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' + f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.', ) return mod_type @@ -68,7 +73,8 @@ def parse_ele_str(ele: str, max_block_idx: int) -> list: elif ele.startswith('last-'): assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}' to_add = list( - range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1)) + range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1), + ) elif ele.startswith('middle-'): assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}' num = int(ele[7:]) @@ -101,7 +107,7 @@ def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list: candidate_block_ids.extend(to_add) else: raise ValueError( - f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}' + f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}', ) elif isinstance(target_blocks, str): target_blocks = target_blocks.replace(' ', '') @@ -110,7 +116,7 @@ def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list: candidate_block_ids.extend(to_add) else: raise ValueError( - f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}' + f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}', ) candidate_block_ids = list(set(candidate_block_ids)) @@ -129,14 +135,17 @@ def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None: else: if all_blocks[vv] is not None: raise ValueError( - f'Block {vv} is assigned to both {k} and {all_blocks[vv]}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.' + f'Block {vv} is assigned to both {k} and {all_blocks[vv]}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.', ) else: all_blocks[vv] = k -def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, - max_block_idx: int) -> dict: +def build_act_ckpt_mod_to_blocks( + act_ckpt_target: Any, + top_module: Any, + max_block_idx: int, +) -> dict: act_ckpt_mod_to_blocks = {} if act_ckpt_target is None or act_ckpt_target == []: mod = top_module @@ -155,7 +164,7 @@ def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, act_ckpt_mod_to_blocks[mod] = block_ids else: raise ValueError( - f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}' + f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}', ) return act_ckpt_mod_to_blocks diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index 859dd3c52b..2d9a8cadd4 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -34,7 +34,8 @@ def create_process_group_ranks(ranks: tuple[int]): distributed.all_gather_object(ranks_gather_list, ranks) ranks_per_subgroup = list(set(ranks_gather_list)) group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration( - ranks_per_subgroup) + ranks_per_subgroup, + ) return group @@ -86,7 +87,7 @@ def config_megablocks_moe_args( import megablocks except: raise RuntimeError( - 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.' + 'Requirements for MegaBlocks not installed; see install instructions in `README.md`.', ) ffn_config.setdefault('fp16', False) @@ -103,10 +104,11 @@ def config_megablocks_moe_args( device_mesh = None device_mesh_cfg = ffn_config.pop('device_mesh', None) if moe_world_size > 1: - if version.parse(torch.__version__.split('.dev')[0]) < version.parse( - '2.2.0'): # type: ignore + if version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.2.0'): # type: ignore raise RuntimeError( - 'MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.' + 'MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.', ) from torch.distributed._tensor.device_mesh import init_device_mesh @@ -114,7 +116,7 @@ def config_megablocks_moe_args( world_size = distributed.get_world_size() if world_size < moe_world_size or world_size % moe_world_size: raise ValueError( - f'Invalid world size configuration: {world_size=} and {moe_world_size=}' + f'Invalid world size configuration: {world_size=} and {moe_world_size=}', ) # FSDP @@ -144,7 +146,7 @@ def config_megablocks_moe_args( lbl_process_group = create_set_process_group(lbl_process_group) elif lbl_process_group is not None: raise ValueError( - f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | .' + f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | .', ) ffn_config['lbl_process_group'] = lbl_process_group diff --git a/llmfoundry/models/utils/meta_init_context.py b/llmfoundry/models/utils/meta_init_context.py index 66f06db581..cf04150516 100644 --- a/llmfoundry/models/utils/meta_init_context.py +++ b/llmfoundry/models/utils/meta_init_context.py @@ -57,8 +57,10 @@ def init_empty_weights(include_buffers: bool = False): """ - with init_on_device(torch.device('meta'), - include_buffers=include_buffers) as f: + with init_on_device( + torch.device('meta'), + include_buffers=include_buffers, + ) as f: yield f @@ -86,8 +88,11 @@ def init_on_device(device: torch.device, include_buffers: bool = False): if include_buffers: old_register_buffer = nn.Module.register_buffer - def register_empty_parameter(self: torch.nn.Module, name: str, - param: Optional[torch.nn.Parameter]): + def register_empty_parameter( + self: torch.nn.Module, + name: str, + param: Optional[torch.nn.Parameter], + ): old_register_parameter(self, name, param) if param is not None: parameter = self._parameters[name] @@ -97,13 +102,17 @@ def register_empty_parameter(self: torch.nn.Module, name: str, else: param_cls = type(parameter) kwargs = parameter.__dict__ - self._parameters[name] = param_cls(parameter.to(device), - **kwargs) - - def register_empty_buffer(self: torch.nn.Module, - name: str, - tensor: Optional[torch.Tensor], - persistent: bool = True): + self._parameters[name] = param_cls( + parameter.to(device), + **kwargs, + ) + + def register_empty_buffer( + self: torch.nn.Module, + name: str, + tensor: Optional[torch.Tensor], + persistent: bool = True, + ): old_register_buffer(self, name, tensor, persistent=persistent) if tensor is not None: named_buffer = self._buffers[name] @@ -133,8 +142,10 @@ def wrapper(*args: Any, **kwargs: Any): nn.Module.register_buffer = register_empty_buffer for torch_function_name in tensor_constructors_to_patch.keys(): setattr( - torch, torch_function_name, - patch_tensor_constructor(getattr(torch, torch_function_name))) + torch, + torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name)), + ) yield finally: nn.Module.register_parameter = old_register_parameter diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py index ca487ecca0..d7b61354c7 100644 --- a/llmfoundry/models/utils/mpt_param_count.py +++ b/llmfoundry/models/utils/mpt_param_count.py @@ -72,8 +72,9 @@ def megablocks_n_total_params(mpt_model) -> int: # type: ignore n_total_params = 0 for module in mpt_model.modules(): if isinstance( - module, - (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)): + module, + (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP), + ): n_w1 = _dtensor_safe_check_numel(module.w1) n_total_params += n_w1 * moe_world_size n_w2 = _dtensor_safe_check_numel(module.w2) @@ -116,8 +117,9 @@ def megablocks_n_active_params(mpt_model) -> int: # type: ignore n_active_params = 0 for module in mpt_model.modules(): if isinstance( - module, - (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)): + module, + (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP), + ): n_w1 = _dtensor_safe_check_numel(module.w1) n_active_params += int(n_w1 / local_experts * moe_top_k) n_w2 = _dtensor_safe_check_numel(module.w2) @@ -170,5 +172,6 @@ def mpt_get_active_params(mpt_model) -> int: # type: ignore if not mpt_model.model.transformer.config.tie_word_embeddings: # Embedding layers are lookup tables, therefore are not counted in the FLOP computation params -= _dtensor_safe_check_numel( - mpt_model.model.transformer.wte.weight) + mpt_model.model.transformer.wte.weight, + ) return params diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 06bdd84438..6ff241870d 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -12,8 +12,12 @@ from torch import nn from torch.distributed._tensor import DTensor -from llmfoundry.layers_registry import (fcs, module_init_fns, norms, - param_init_fns) +from llmfoundry.layers_registry import ( + fcs, + module_init_fns, + norms, + param_init_fns, +) from llmfoundry.models.layers.dmoe import GLU, MLP try: @@ -37,8 +41,10 @@ def torch_default_param_init_fn_( ) -> None: del kwargs # unused, just to capture any extra args from the config - if hasattr(module, 'reset_parameters') and isinstance( - module.reset_parameters, Callable): + if hasattr( + module, + 'reset_parameters', + ) and isinstance(module.reset_parameters, Callable): module.reset_parameters() @@ -161,7 +167,7 @@ def fc_init( ) -> bool: del kwargs # unused, just to capture any extra args - if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))): + if isinstance(module, tuple({fcs.get(n) for n in fcs.get_all()})): # Linear if hasattr(module, '_fused'): fused_init_helper_(module, init_fn_) @@ -172,7 +178,10 @@ def fc_init( torch.nn.init.zeros_(module.bias) if init_div_is_residual is not False and getattr( - module, '_is_residual', False): + module, + '_is_residual', + False, + ): with torch.no_grad(): module.weight.div_(div_is_residual) # type: ignore return True @@ -201,7 +210,7 @@ def embedding_init( if isinstance(lim, Sequence): if len(lim) > 2: raise ValueError( - f'Uniform init requires a min and a max limit. User input: {lim}.' + f'Uniform init requires a min and a max limit. User input: {lim}.', ) if lim[0] == lim[1]: warnings.warn(f'Embedding layer initialized to {lim[0]}.') @@ -227,11 +236,13 @@ def norm_init( ) -> bool: del kwargs # unused, just to capture any extra args - if isinstance(module, - tuple(set([norms.get(name) for name in norms.get_all()]))): + if isinstance( + module, + tuple({norms.get(name) for name in norms.get_all()}), + ): # Norm - if hasattr(module, 'weight') and isinstance(module.weight, - torch.Tensor): + if hasattr(module, + 'weight') and isinstance(module.weight, torch.Tensor): torch.nn.init.ones_(module.weight) if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor): torch.nn.init.zeros_(module.bias) @@ -280,7 +291,10 @@ def multihead_attention_init( # out proj init_fn_(module.out_proj.weight) if init_div_is_residual is not False and getattr( - module.out_proj, '_is_residual', False): + module.out_proj, + '_is_residual', + False, + ): with torch.no_grad(): module.out_proj.weight.div_(div_is_residual) if module.out_proj.bias is not None: @@ -328,31 +342,51 @@ def moe_init( div_is_residual: float, **kwargs: Any, ) -> bool: - if megablocks is not None and isinstance(module, ( + if megablocks is not None and isinstance( + module, + ( megablocks.layers.moe.MoE, megablocks.layers.dmoe.dMoE, megablocks.layers.moe.ParallelMLP, megablocks.layers.dmoe.ParallelDroplessMLP, - )): + ), + ): if hasattr(module, 'bias') and module.bias is not None: # Initialize bias to 0 torch.nn.init.zeros_(module.bias) # type: ignore return True - elif megablocks is not None and isinstance(module, - megablocks.layers.glu.SparseGLU): + elif megablocks is not None and isinstance( + module, + megablocks.layers.glu.SparseGLU, + ): _megablocks_sparse_glu_generic_param_init_fn_( - module, init_fn_, bool(init_div_is_residual), div_is_residual) + module, + init_fn_, + bool(init_div_is_residual), + div_is_residual, + ) return True - elif megablocks is not None and isinstance(module, - megablocks.layers.mlp.SparseMLP): + elif megablocks is not None and isinstance( + module, + megablocks.layers.mlp.SparseMLP, + ): _megablocks_sparse_mlp_generic_param_init_fn_( - module, init_fn_, bool(init_div_is_residual), div_is_residual) + module, + init_fn_, + bool(init_div_is_residual), + div_is_residual, + ) return True - elif megablocks is not None and isinstance(module, - megablocks.layers.mlp.MLP): - _megablocks_mlp_generic_param_init_fn_(module, init_fn_, - bool(init_div_is_residual), - div_is_residual) + elif megablocks is not None and isinstance( + module, + megablocks.layers.mlp.MLP, + ): + _megablocks_mlp_generic_param_init_fn_( + module, + init_fn_, + bool(init_div_is_residual), + div_is_residual, + ) return True elif isinstance(module, GLU): init_fn_(module.w1) @@ -388,8 +422,8 @@ def generic_param_init_fn_( div_is_residual = 1.0 elif init_div_is_residual is True: div_is_residual = math.sqrt(2 * n_layers) - elif isinstance(init_div_is_residual, float) or isinstance( - init_div_is_residual, int): + elif isinstance(init_div_is_residual, + float) or isinstance(init_div_is_residual, int): div_is_residual = init_div_is_residual elif init_div_is_residual.isnumeric(): # do not trust YAML parsing to always convert numbers to numbers @@ -398,7 +432,7 @@ def generic_param_init_fn_( # not used, for pyright div_is_residual = 1.0 raise ValueError( - f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' + f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}', ) all_module_init_fns = [ @@ -426,7 +460,8 @@ def generic_param_init_fn_( f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' - + ', '.join(module_init_fns.get_all())) + + ', '.join(module_init_fns.get_all()), + ) def _megablocks_sparse_mlp_generic_param_init_fn_( @@ -452,13 +487,16 @@ def _megablocks_sparse_mlp_generic_param_init_fn_( expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( - module.expert_parallel_group.size()) # type: ignore + module.expert_parallel_group.size(), + ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore if module.weight_parallel_group is not None: weight_parallel_group_size = int( - module.weight_parallel_group.size()) # type: ignore + module.weight_parallel_group.size(), + ) # type: ignore weight_parallel_group_rank = int( - module.weight_parallel_group.rank()) # type: ignore + module.weight_parallel_group.rank(), + ) # type: ignore hidden_size = int(module.hidden_size) # type: ignore @@ -525,19 +563,23 @@ def _megablocks_sparse_glu_generic_param_init_fn_( module=module, init_fn_=init_fn_, init_div_is_residual=init_div_is_residual, - div_is_residual=div_is_residual) + div_is_residual=div_is_residual, + ) # Init ported from _megablocks_sparse_mlp_generic_param_init_fn_ for v1 expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( - module.expert_parallel_group.size()) # type: ignore + module.expert_parallel_group.size(), + ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore if module.weight_parallel_group is not None: weight_parallel_group_size = int( - module.weight_parallel_group.size()) # type: ignore + module.weight_parallel_group.size(), + ) # type: ignore weight_parallel_group_rank = int( - module.weight_parallel_group.rank()) # type: ignore + module.weight_parallel_group.rank(), + ) # type: ignore hidden_size = int(module.hidden_size) # type: ignore @@ -584,11 +626,13 @@ def _megablocks_mlp_generic_param_init_fn_( expert_process_group_size, rank, weight_parallel_group_size, w_rank = 1, 0, 1, 0 if module.expert_parallel_group is not None: expert_process_group_size = int( - module.expert_parallel_group.size()) # type: ignore + module.expert_parallel_group.size(), + ) # type: ignore rank = int(module.expert_parallel_group.rank()) # type: ignore if module.weight_parallel_group is not None: weight_parallel_group_size = int( - module.weight_parallel_group.size()) # type: ignore + module.weight_parallel_group.size(), + ) # type: ignore w_rank = int(module.weight_parallel_group.rank()) # type: ignore _init_fn_ = _flip_fan_mode(init_fn_) @@ -660,7 +704,7 @@ def baseline_param_init_fn_( del kwargs # unused, just to capture any extra args from the config if init_std is None: raise ValueError( - "You must set model.init_config['init_std'] to a float value to use the default initialization scheme." + "You must set model.init_config['init_std'] to a float value to use the default initialization scheme.", ) _normal_param_init_fn_( module=module, @@ -738,10 +782,12 @@ def kaiming_uniform_param_init_fn_( ) -> None: del kwargs # unused, just to capture any extra args from the config - kaiming_uniform_ = partial(nn.init.kaiming_uniform_, - a=init_gain, - mode=fan_mode, - nonlinearity=init_nonlinearity) + kaiming_uniform_ = partial( + nn.init.kaiming_uniform_, + a=init_gain, + mode=fan_mode, + nonlinearity=init_nonlinearity, + ) generic_param_init_fn_( module=module, @@ -768,10 +814,12 @@ def kaiming_normal_param_init_fn_( ) -> None: del kwargs # unused, just to capture any extra args from the config - kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, - a=init_gain, - mode=fan_mode, - nonlinearity=init_nonlinearity) + kaiming_normal_ = partial( + torch.nn.init.kaiming_normal_, + a=init_gain, + mode=fan_mode, + nonlinearity=init_nonlinearity, + ) generic_param_init_fn_( module=module, diff --git a/llmfoundry/optim/__init__.py b/llmfoundry/optim/__init__.py index 26389665b5..0b55944338 100644 --- a/llmfoundry/optim/__init__.py +++ b/llmfoundry/optim/__init__.py @@ -1,9 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from composer.optim import (ConstantWithWarmupScheduler, - CosineAnnealingWithWarmupScheduler, DecoupledAdamW, - LinearWithWarmupScheduler) +from composer.optim import ( + ConstantWithWarmupScheduler, + CosineAnnealingWithWarmupScheduler, + DecoupledAdamW, + LinearWithWarmupScheduler, +) from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion from llmfoundry.optim.lion import DecoupledLionW @@ -16,11 +19,15 @@ optimizers.register('decoupled_adamw', func=DecoupledAdamW) schedulers.register('constant_with_warmup', func=ConstantWithWarmupScheduler) -schedulers.register('cosine_with_warmup', - func=CosineAnnealingWithWarmupScheduler) +schedulers.register( + 'cosine_with_warmup', + func=CosineAnnealingWithWarmupScheduler, +) schedulers.register('linear_decay_with_warmup', func=LinearWithWarmupScheduler) -schedulers.register('inv_sqrt_with_warmup', - func=InverseSquareRootWithWarmupScheduler) +schedulers.register( + 'inv_sqrt_with_warmup', + func=InverseSquareRootWithWarmupScheduler, +) __all__ = [ 'DecoupledLionW', diff --git a/llmfoundry/optim/adaptive_lion.py b/llmfoundry/optim/adaptive_lion.py index 9b2dac9d80..cb4ce59cd0 100644 --- a/llmfoundry/optim/adaptive_lion.py +++ b/llmfoundry/optim/adaptive_lion.py @@ -42,39 +42,41 @@ class DecoupledAdaLRLion(Optimizer): """ metric_functions = { 'l2_norm/moment': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - optim_state['exp_avg']), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(optim_state['exp_avg']), 'l2_norm/param': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - param.data), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(param.data), 'l2_norm/update': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - step_tensor), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(step_tensor), 'l2_norm/grad': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - param.grad), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(param.grad), } - def __init__(self, - params: Union[Iterable[torch.Tensor], Iterable[dict]], - lr: float = 1e-4, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.0, - outlier_threshold: float = 10.0, - timeout: int = 100, - lr_penalty: float = .707, - min_scale: float = 1e-4): + def __init__( + self, + params: Union[Iterable[torch.Tensor], Iterable[dict]], + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + outlier_threshold: float = 10.0, + timeout: int = 100, + lr_penalty: float = .707, + min_scale: float = 1e-4, + ): if lr <= 0.: raise Exception(f'Invalid LR: {lr}. LR must be > 0') - if not all([0. <= beta <= 1. for beta in betas]): + if not all(0. <= beta <= 1. for beta in betas): raise Exception( - f'Invalid beta values: {betas} All betas must be between 0 and 1.' + f'Invalid beta values: {betas} All betas must be between 0 and 1.', ) if weight_decay >= 1e-3: log.warning( f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledLionW` optimizer. Are you sure you want to do this? ' + - f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!' + f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!', ) defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay} @@ -89,9 +91,16 @@ def __init__(self, self.min_scale = min_scale @staticmethod - def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, - lr: float, initial_lr: float, wd: float, beta1: float, - beta2: float) -> None: + def lionw( + p: torch.Tensor, + grad: torch.Tensor, + exp_avg: torch.Tensor, + lr: float, + initial_lr: float, + wd: float, + beta1: float, + beta2: float, + ) -> None: # stepweight decay if wd != 0: decay_factor = (lr / initial_lr) if initial_lr else 1.0 @@ -105,8 +114,12 @@ def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, exp_avg.lerp_(grad, 1 - beta2) @staticmethod - def adjust_lr(lr: float, lr_penalty: float, num_times: int, - min_scale: float) -> float: + def adjust_lr( + lr: float, + lr_penalty: float, + num_times: int, + min_scale: float, + ) -> float: """Adjusts LR. Multiplicatively scales down the LR by lr_penalty for each outlier @@ -133,8 +146,10 @@ def step(self, closure: Optional[Callable] = None): loss = closure() for group in self.param_groups: - for p in filter(lambda p: p.grad is not None and p.requires_grad, - group['params']): + for p in filter( + lambda p: p.grad is not None and p.requires_grad, + group['params'], + ): grad, lr, initial_lr, wd, beta1, beta2, state = p.grad, group[ 'lr'], group['initial_lr'], group[ @@ -145,7 +160,8 @@ def step(self, closure: Optional[Callable] = None): if len(state) == 0: state['exp_avg'] = torch.zeros_like(p) state['moment_tracker'] = OutlierDetector( - self.outlier_threshold) + self.outlier_threshold, + ) state['outlier_timestamp'] = [] state['step'] = 0 @@ -153,7 +169,8 @@ def step(self, closure: Optional[Callable] = None): # determine if the new moment resulting from this grad would be an outlier moment_norm = torch.linalg.vector_norm( - exp_avg.lerp(grad, 1 - beta2))**2 + exp_avg.lerp(grad, 1 - beta2), + )**2 if dist.get_world_size() > 1: dist.all_reduce(moment_norm, reduce_operation='SUM') @@ -162,17 +179,20 @@ def step(self, closure: Optional[Callable] = None): if state['moment_tracker'].insert_observation(moment_norm): state['outlier_timestamp'].append(state['step']) - removed = [] - for ts in state['outlier_timestamp']: - if state['step'] - ts > self.timeout: - removed.append(ts) + removed = [ + ts for ts in state['outlier_timestamp'] + if state['step'] - ts > self.timeout + ] for ts in removed: state['outlier_timestamp'].remove(ts) - lr = self.adjust_lr(lr, self.lr_penalty, - len(state['outlier_timestamp']), - self.min_scale) + lr = self.adjust_lr( + lr, + self.lr_penalty, + len(state['outlier_timestamp']), + self.min_scale, + ) self.lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2) state['step'] += 1 @@ -197,8 +217,8 @@ def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}'] B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}'] - optimizer_metrics[metric] = reduced / (A_reduced_norm * - B_reduced_norm) + optimizer_metrics[ + metric] = reduced / (A_reduced_norm * B_reduced_norm) elif metric.startswith('layerwise_lr'): continue else: @@ -217,8 +237,12 @@ def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): optimizer_metrics[metric] = optimizer_metrics[metric]**2 return optimizer_metrics - def report_per_parameter_metrics(self, param: torch.Tensor, name: str, - optimizer_metrics: dict): + def report_per_parameter_metrics( + self, + param: torch.Tensor, + name: str, + optimizer_metrics: dict, + ): lr = self.param_groups[0]['lr'] weight_decay = self.param_groups[0]['weight_decay'] initial_lr = self.param_groups[0]['initial_lr'] @@ -227,11 +251,16 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, if param in self.state: param_optim_state = self.state[param] layerwise_lr = self.adjust_lr( - lr, self.lr_penalty, - len(param_optim_state['outlier_timestamp']), self.min_scale) + lr, + self.lr_penalty, + len(param_optim_state['outlier_timestamp']), + self.min_scale, + ) step_tensor = param_optim_state['exp_avg'].clone().lerp_( - param.grad, 1 - beta1).sign_().mul_(lr) + param.grad, + 1 - beta1, + ).sign_().mul_(lr) decay_factor = (lr / initial_lr) if initial_lr else 1.0 step_tensor.add_(param, alpha=-weight_decay * decay_factor) for metric in self.metric_functions: @@ -239,7 +268,8 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, metric](param, param_optim_state, step_tensor) optimizer_metrics[f'layerwise_lr/{name}'] = torch.tensor( - layerwise_lr) + layerwise_lr, + ) return optimizer_metrics @@ -263,36 +293,38 @@ class DecoupledClipLion(Optimizer): """ metric_functions = { 'l2_norm/moment': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - optim_state['exp_avg']), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(optim_state['exp_avg']), 'l2_norm/param': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - param.data), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(param.data), 'l2_norm/update': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - step_tensor), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(step_tensor), 'l2_norm/grad': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - param.grad), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(param.grad), } - def __init__(self, - params: Union[Iterable[torch.Tensor], Iterable[dict]], - lr: float = 1e-4, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.0, - outlier_threshold: float = 5.0): + def __init__( + self, + params: Union[Iterable[torch.Tensor], Iterable[dict]], + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + outlier_threshold: float = 5.0, + ): if lr <= 0.: raise Exception(f'Invalid LR: {lr}. LR must be > 0') - if not all([0. <= beta <= 1. for beta in betas]): + if not all(0. <= beta <= 1. for beta in betas): raise Exception( - f'Invalid beta values: {betas} All betas must be between 0 and 1.' + f'Invalid beta values: {betas} All betas must be between 0 and 1.', ) if weight_decay >= 1e-3: log.warning( f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledLionW` optimizer. Are you sure you want to do this? ' + - f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!' + f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!', ) defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay} @@ -304,9 +336,16 @@ def __init__(self, self.outlier_threshold = outlier_threshold @staticmethod - def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, - lr: float, initial_lr: float, wd: float, beta1: float, - beta2: float) -> None: + def lionw( + p: torch.Tensor, + grad: torch.Tensor, + exp_avg: torch.Tensor, + lr: float, + initial_lr: float, + wd: float, + beta1: float, + beta2: float, + ) -> None: # stepweight decay if wd != 0: decay_factor = (lr / initial_lr) if initial_lr else 1.0 @@ -328,8 +367,10 @@ def step(self, closure: Optional[Callable] = None): loss = closure() for group in self.param_groups: - for p in filter(lambda p: p.grad is not None and p.requires_grad, - group['params']): + for p in filter( + lambda p: p.grad is not None and p.requires_grad, + group['params'], + ): grad, lr, initial_lr, wd, beta1, beta2, state = p.grad, group[ 'lr'], group['initial_lr'], group[ @@ -340,7 +381,8 @@ def step(self, closure: Optional[Callable] = None): if len(state) == 0: state['exp_avg'] = torch.zeros_like(p) state['grad_tracker'] = OutlierDetector( - self.outlier_threshold) + self.outlier_threshold, + ) state['clipped_batches'] = torch.tensor(0.0) exp_avg = state['exp_avg'] @@ -393,8 +435,10 @@ def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): """Preprocess metrics to reduce across ranks correctly.""" # Sort L2 norms first so they are squared before other metrics, which depend on squared values metrics = optimizer_metrics.keys() - metrics = sorted(metrics, - key=lambda metric: 0 if 'l2_norm' in metric else 1) + metrics = sorted( + metrics, + key=lambda metric: 0 if 'l2_norm' in metric else 1, + ) for metric in metrics: if metric.startswith('l2_norm'): # L2 norms need to be squared, before they are reduced via summation @@ -406,17 +450,23 @@ def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): # L2 norm would've been squared in previous branch A_rank_subset_norm = math.sqrt( - optimizer_metrics[f'l2_norm/{A}/{layer}']) + optimizer_metrics[f'l2_norm/{A}/{layer}'], + ) B_rank_subset_norm = math.sqrt( - optimizer_metrics[f'l2_norm/{B}/{layer}']) + optimizer_metrics[f'l2_norm/{B}/{layer}'], + ) - optimizer_metrics[ - metric] *= A_rank_subset_norm * B_rank_subset_norm + optimizer_metrics[metric + ] *= A_rank_subset_norm * B_rank_subset_norm return optimizer_metrics - def report_per_parameter_metrics(self, param: torch.Tensor, name: str, - optimizer_metrics: dict): + def report_per_parameter_metrics( + self, + param: torch.Tensor, + name: str, + optimizer_metrics: dict, + ): lr = self.param_groups[0]['lr'] weight_decay = self.param_groups[0]['weight_decay'] initial_lr = self.param_groups[0]['initial_lr'] @@ -425,7 +475,9 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, if param in self.state: param_optim_state = self.state[param] step_tensor = param_optim_state['exp_avg'].clone().lerp_( - param.grad, 1 - beta1).sign_().mul_(lr) + param.grad, + 1 - beta1, + ).sign_().mul_(lr) decay_factor = (lr / initial_lr) if initial_lr else 1.0 step_tensor.add_(param, alpha=-weight_decay * decay_factor) for metric in self.metric_functions: diff --git a/llmfoundry/optim/lion.py b/llmfoundry/optim/lion.py index b04211649c..667c3f55b1 100644 --- a/llmfoundry/optim/lion.py +++ b/llmfoundry/optim/lion.py @@ -19,37 +19,37 @@ class DecoupledLionW(Optimizer): metric_functions = { 'l2_norm/moment': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - optim_state['exp_avg']), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(optim_state['exp_avg']), 'l2_norm/param': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - param.data), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(param.data), 'l2_norm/update': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - step_tensor), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(step_tensor), 'l2_norm/grad': - lambda param, optim_state, step_tensor: torch.linalg.vector_norm( - param.grad), + lambda param, optim_state, step_tensor: torch.linalg. + vector_norm(param.grad), } def __init__( - self, - params: Union[Iterable[torch.Tensor], Iterable[dict]], - lr: float = 1e-4, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.0, + self, + params: Union[Iterable[torch.Tensor], Iterable[dict]], + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, ): if lr <= 0.: raise Exception(f'Invalid LR: {lr}. LR must be > 0') - if not all([0. <= beta <= 1. for beta in betas]): + if not all(0. <= beta <= 1. for beta in betas): raise Exception( - f'Invalid beta values: {betas} All betas must be between 0 and 1.' + f'Invalid beta values: {betas} All betas must be between 0 and 1.', ) if weight_decay >= 1e-3: log.warning( f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledLionW` optimizer. Are you sure you want to do this? ' + - f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!' + f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!', ) defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay} @@ -60,9 +60,16 @@ def __init__( group['initial_lr'] = group['lr'] @staticmethod - def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, - lr: float, initial_lr: float, wd: float, beta1: float, - beta2: float) -> None: + def lionw( + p: torch.Tensor, + grad: torch.Tensor, + exp_avg: torch.Tensor, + lr: float, + initial_lr: float, + wd: float, + beta1: float, + beta2: float, + ) -> None: # stepweight decay if wd != 0: decay_factor = (lr / initial_lr) if initial_lr else 1.0 @@ -84,8 +91,10 @@ def step(self, closure: Optional[Callable] = None): loss = closure() for group in self.param_groups: - for p in filter(lambda p: p.grad is not None and p.requires_grad, - group['params']): + for p in filter( + lambda p: p.grad is not None and p.requires_grad, + group['params'], + ): grad, lr, initial_lr, wd, beta1, beta2, state = p.grad, group[ 'lr'], group['initial_lr'], group[ @@ -135,8 +144,12 @@ def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): optimizer_metrics[metric] = optimizer_metrics[metric]**2 return optimizer_metrics - def report_per_parameter_metrics(self, param: torch.Tensor, name: str, - optimizer_metrics: dict): + def report_per_parameter_metrics( + self, + param: torch.Tensor, + name: str, + optimizer_metrics: dict, + ): lr = self.param_groups[0]['lr'] weight_decay = self.param_groups[0]['weight_decay'] initial_lr = self.param_groups[0]['initial_lr'] @@ -145,7 +158,9 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, if param in self.state: param_optim_state = self.state[param] step_tensor = param_optim_state['exp_avg'].clone().lerp_( - param.grad, 1 - beta1).sign_().mul_(lr) + param.grad, + 1 - beta1, + ).sign_().mul_(lr) decay_factor = (lr / initial_lr) if initial_lr else 1.0 step_tensor.add_(param, alpha=-weight_decay * decay_factor) for metric in self.metric_functions: diff --git a/llmfoundry/optim/outlier_detection.py b/llmfoundry/optim/outlier_detection.py index e430f4ccb5..38567b7acd 100644 --- a/llmfoundry/optim/outlier_detection.py +++ b/llmfoundry/optim/outlier_detection.py @@ -46,8 +46,9 @@ def insert_observation(self, obs: float) -> bool: """ assert self.intermediate_data_queue.maxlen is not None, 'expected maxlen defined' - if len(self.intermediate_data_queue - ) >= self.intermediate_data_queue.maxlen: + if len( + self.intermediate_data_queue, + ) >= self.intermediate_data_queue.maxlen: # move data from intermediate queue to slow moving average queue intermediate_obs = self.intermediate_data_queue.popleft() self.delayed_moving_average.append(intermediate_obs) @@ -58,7 +59,8 @@ def insert_observation(self, obs: float) -> bool: def get_delayed_mva(self) -> Optional[float]: if len(self.delayed_moving_average) > 0: - return sum(self.delayed_moving_average) / len( - self.delayed_moving_average) + return sum( + self.delayed_moving_average, + ) / len(self.delayed_moving_average) else: return None diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index 655093d138..5c45fc1eca 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -16,14 +16,17 @@ __all__ = ['InverseSquareRootWithWarmupScheduler'] -def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], - name: str) -> None: +def _raise_if_units_dont_match( + time: Union[str, Time], + t_max: Union[str, Time], + name: str, +) -> None: new_time = Time.from_input(time) new_t_max = Time.from_input(t_max) if new_time.unit != new_t_max.unit: raise ValueError( - f'{name} (unit {new_time.unit=}) must match max_duration unit ({new_t_max.unit=}).' + f'{name} (unit {new_time.unit=}) must match max_duration unit ({new_t_max.unit=}).', ) @@ -77,17 +80,21 @@ class InverseSquareRootWithWarmupScheduler(ComposerScheduler): alpha_f_cooldown (float): The learning rate multiplier to decay linear cooldown to. Default = ``0.0``. """ - def __init__(self, - t_warmup: Union[str, Time], - t_scale: Union[str, Time], - t_cooldown: Union[str, Time], - t_max: Union[str, Time] = '1dur', - alpha_f_decay: float = 0.0, - alpha_f_cooldown: float = 0.0) -> None: + def __init__( + self, + t_warmup: Union[str, Time], + t_scale: Union[str, Time], + t_cooldown: Union[str, Time], + t_max: Union[str, Time] = '1dur', + alpha_f_decay: float = 0.0, + alpha_f_cooldown: float = 0.0, + ) -> None: if alpha_f_decay < alpha_f_cooldown: - raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. ' - f'Current: alpha_f_decay={alpha_f_decay}, ' - f'alpha_f_cooldown={alpha_f_cooldown}.')) + raise ValueError(( + 'Required: alpha_f_decay >= alpha_f_cooldown. ' + f'Current: alpha_f_decay={alpha_f_decay}, ' + f'alpha_f_cooldown={alpha_f_cooldown}.' + )) _raise_if_units_dur(t_warmup, 't_warmup') _raise_if_units_dur(t_scale, 't_scale') _raise_if_units_dur(t_cooldown, 't_cooldown') @@ -97,25 +104,36 @@ def __init__(self, self.t_max = t_max self.alpha_f_decay = alpha_f_decay self.alpha_f_cooldown = alpha_f_cooldown - self.warmup_scheduler = LinearScheduler(alpha_i=0.0, - alpha_f=1.0, - t_max=t_warmup) + self.warmup_scheduler = LinearScheduler( + alpha_i=0.0, + alpha_f=1.0, + t_max=t_warmup, + ) def __call__(self, state: State, ssr: float = 1.0) -> float: assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked' - _raise_if_units_dont_match(self.t_warmup, state.max_duration, - 't_warmup') + _raise_if_units_dont_match( + self.t_warmup, + state.max_duration, + 't_warmup', + ) _raise_if_units_dont_match(self.t_scale, state.max_duration, 't_scale') - _raise_if_units_dont_match(self.t_cooldown, state.max_duration, - 't_cooldown') + _raise_if_units_dont_match( + self.t_cooldown, + state.max_duration, + 't_cooldown', + ) t_warmup = _convert_time(self.t_warmup, state) if t_warmup.value == 0: warnings.warn( - textwrap.dedent("""\ + textwrap.dedent( + """\ The warmup duration is 0. If warmup was specified as a fraction of the total training duration, the warmup duration is calculated in the - same unit as the trainer's max_duration parameter.""")) + same unit as the trainer's max_duration parameter.""", + ), + ) if state.timestamp < t_warmup: return self.warmup_scheduler(state) @@ -136,8 +154,9 @@ def __call__(self, state: State, ssr: float = 1.0) -> float: # elapsed after warmup, rescaled by the time scale, such that, at # infinite time, the LR decays to alpha_f_decay. coeff = 1 / ((current_time + t_shift) / t_scale).value**0.5 - current_factor = (self.alpha_f_decay + coeff * - (1.0 - self.alpha_f_decay)) + current_factor = ( + self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay) + ) return current_factor else: @@ -152,6 +171,7 @@ def __call__(self, state: State, ssr: float = 1.0) -> float: frac_of_cooldown = ((current_time - t_cooldown_start) / t_cooldown).value frac_of_cooldown = min(1.0, frac_of_cooldown) - current_factor = (alpha_i + frac_of_cooldown * - (self.alpha_f_cooldown - alpha_i)) + current_factor = ( + alpha_i + frac_of_cooldown * (self.alpha_f_cooldown - alpha_i) + ) return current_factor diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 6e1824ea08..eb971a61af 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,10 +12,17 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.layers_registry import (attention_classes, - attention_implementations, fcs, ffns, - ffns_with_megablocks, ffns_with_norm, - module_init_fns, norms, param_init_fns) +from llmfoundry.layers_registry import ( + attention_classes, + attention_implementations, + fcs, + ffns, + ffns_with_megablocks, + ffns_with_norm, + module_init_fns, + norms, + param_init_fns, +) from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -25,11 +32,13 @@ + 'will be constructed by directly passing along the specified kwargs to the constructor.' ) -loggers = create_registry('llmfoundry', - 'loggers', - generic_type=Type[LoggerDestination], - entry_points=True, - description=_loggers_description) +loggers = create_registry( + 'llmfoundry', + 'loggers', + generic_type=Type[LoggerDestination], + entry_points=True, + description=_loggers_description, +) _callbacks_description = ( 'The callbacks registry is used to register classes that implement the Callback interface. ' @@ -38,11 +47,13 @@ + 'The callbacks will be constructed by directly passing along the specified kwargs to the constructor.' ) -callbacks = create_registry('llmfoundry', - 'callbacks', - generic_type=Type[Callback], - entry_points=True, - description=_callbacks_description) +callbacks = create_registry( + 'llmfoundry', + 'callbacks', + generic_type=Type[Callback], + entry_points=True, + description=_callbacks_description, +) _callbacks_with_config_description = ( 'The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. ' @@ -53,40 +64,50 @@ 'llm_foundry.callbacks_with_config', generic_type=Type[CallbackWithConfig], entry_points=True, - description=_callbacks_with_config_description) + description=_callbacks_with_config_description, +) _optimizers_description = ( 'The optimizers registry is used to register classes that implement the Optimizer interface. ' + 'The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the ' - + 'specified kwargs to the constructor, along with the model parameters.') -optimizers = create_registry('llmfoundry', - 'optimizers', - generic_type=Type[Optimizer], - entry_points=True, - description=_optimizers_description) + + 'specified kwargs to the constructor, along with the model parameters.' +) +optimizers = create_registry( + 'llmfoundry', + 'optimizers', + generic_type=Type[Optimizer], + entry_points=True, + description=_optimizers_description, +) _algorithms_description = ( 'The algorithms registry is used to register classes that implement the Algorithm interface. ' + 'The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the ' - + 'specified kwargs to the constructor.') -algorithms = create_registry('llmfoundry', - 'algorithms', - generic_type=Type[Algorithm], - entry_points=True, - description=_algorithms_description) + + 'specified kwargs to the constructor.' +) +algorithms = create_registry( + 'llmfoundry', + 'algorithms', + generic_type=Type[Algorithm], + entry_points=True, + description=_algorithms_description, +) _schedulers_description = ( 'The schedulers registry is used to register classes that implement the ComposerScheduler interface. ' + 'The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the ' - + 'specified kwargs to the constructor.') -schedulers = create_registry('llmfoundry', - 'schedulers', - generic_type=Type[ComposerScheduler], - entry_points=True, - description=_schedulers_description) + + 'specified kwargs to the constructor.' +) +schedulers = create_registry( + 'llmfoundry', + 'schedulers', + generic_type=Type[ComposerScheduler], + entry_points=True, + description=_schedulers_description, +) _models_description = ( 'The models registry is used to register classes that implement the ComposerModel interface. ' @@ -95,11 +116,13 @@ + 'Note: This will soon be updated to take in named kwargs instead of a config directly.' ) -models = create_registry('llmfoundry', - 'models', - generic_type=Type[ComposerModel], - entry_points=True, - description=_models_description) +models = create_registry( + 'llmfoundry', + 'models', + generic_type=Type[ComposerModel], + entry_points=True, + description=_models_description, +) _dataloaders_description = ( 'The dataloaders registry is used to register functions that create a DataSpec. The function should take ' @@ -111,16 +134,19 @@ 'dataloaders', generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec], entry_points=True, - description=_dataloaders_description) + description=_dataloaders_description, +) _metrics_description = ( 'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.' ) -metrics = create_registry('llmfoundry', - 'metrics', - generic_type=Type[Metric], - entry_points=True, - description=_metrics_description) +metrics = create_registry( + 'llmfoundry', + 'metrics', + generic_type=Type[Metric], + entry_points=True, + description=_metrics_description, +) __all__ = [ 'loggers', diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 0ecaa45b5f..f087664344 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -29,11 +29,13 @@ def bytes_to_unicode(): 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. """ - bs = (list(range(ord('!'), - ord('~') + 1)) + list(range(ord('¡'), - ord('¬') + 1)) + - list(range(ord('®'), - ord('ÿ') + 1))) + bs = ( + list(range(ord('!'), + ord('~') + 1)) + list(range(ord('¡'), + ord('¬') + 1)) + + list(range(ord('®'), + ord('ÿ') + 1)) + ) cs = bs[:] n = 0 for b in range(2**8): @@ -55,18 +57,20 @@ class TiktokenTokenizerWrapper(PreTrainedTokenizer): model_input_names = ['input_ids', 'attention_mask'] - def __init__(self, - model_name: Optional[str] = None, - encoding_name: Optional[str] = None, - add_bos_token: bool = False, - add_eos_token: bool = False, - use_default_system_prompt: bool = False, - unk_token: Optional[str] = '<|endoftext|>', - eos_token: Optional[str] = '<|endoftext|>', - bos_token: Optional[str] = '<|endoftext|>', - pad_token: Optional[str] = None, - errors: str = 'replace', - **kwargs: Any): + def __init__( + self, + model_name: Optional[str] = None, + encoding_name: Optional[str] = None, + add_bos_token: bool = False, + add_eos_token: bool = False, + use_default_system_prompt: bool = False, + unk_token: Optional[str] = '<|endoftext|>', + eos_token: Optional[str] = '<|endoftext|>', + bos_token: Optional[str] = '<|endoftext|>', + pad_token: Optional[str] = None, + errors: str = 'replace', + **kwargs: Any, + ): """Constructor creates a tiktoken tokenizer to use as the underlying. tokenizer. @@ -91,7 +95,8 @@ def __init__(self, import tiktoken except: raise ImportError( - 'You need to install tiktoken to use TiktokenTokenizerWrapper.') + 'You need to install tiktoken to use TiktokenTokenizerWrapper.', + ) # Workaround to make tiktokenizer picklable. # https://github.com/huggingface/datasets/issues/5536#issuecomment-1682309347 @@ -102,17 +107,22 @@ def __init__(self, from tiktoken import Encoding # type: ignore (thirdParty) def pickle_Encoding(enc: Encoding): - return (functools.partial(Encoding, - enc.name, - pat_str=enc._pat_str, - mergeable_ranks=enc._mergeable_ranks, - special_tokens=enc._special_tokens), ()) + return ( + functools.partial( + Encoding, + enc.name, + pat_str=enc._pat_str, + mergeable_ranks=enc._mergeable_ranks, + special_tokens=enc._special_tokens, + ), + (), + ) copyreg.pickle(Encoding, pickle_Encoding) if model_name is not None and encoding_name is not None: raise ValueError( - 'You need to specify either model_name or encoding_name, not both.' + 'You need to specify either model_name or encoding_name, not both.', ) self.model_name = model_name @@ -126,7 +136,8 @@ def pickle_Encoding(enc: Encoding): self.encoding_name) else: raise ValueError( - 'You need to specify either model_name or encoding_name.') + 'You need to specify either model_name or encoding_name.', + ) self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token @@ -155,17 +166,19 @@ def pickle_Encoding(enc: Encoding): if i in self.decoder: self.encoder[self.decoder[i]] = i - super().__init__(model_name=model_name, - encoding_name=encoding_name, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - use_default_system_prompt=use_default_system_prompt, - unk_token=unk_token, - eos_token=eos_token, - bos_token=bos_token, - pad_token=pad_token, - errors=errors, - **kwargs) + super().__init__( + model_name=model_name, + encoding_name=encoding_name, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + use_default_system_prompt=use_default_system_prompt, + unk_token=unk_token, + eos_token=eos_token, + bos_token=bos_token, + pad_token=pad_token, + errors=errors, + **kwargs, + ) @property def vocab_size(self) -> int: @@ -205,12 +218,16 @@ def default_chat_template(self): '{% if (add_generation_prompt == true and loop.last) %}' "{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}" '{% endif %}' - '{% endfor %}') + '{% endfor %}' + ) template = template.replace( 'USE_DEFAULT_PROMPT', - 'true' if self.use_default_system_prompt else 'false') - template = template.replace('DEFAULT_SYSTEM_PROMPT', - DEFAULT_SYSTEM_PROMPT) + 'true' if self.use_default_system_prompt else 'false', + ) + template = template.replace( + 'DEFAULT_SYSTEM_PROMPT', + DEFAULT_SYSTEM_PROMPT, + ) return template def get_vocab(self) -> Dict[str, int]: @@ -221,8 +238,9 @@ def get_vocab(self) -> Dict[str, int]: vocab_clone = self.encoder.copy() extra_id_index = 0 candidate_extra_id = f'' - indices_to_fill_in = {i for i in range(self.vocab_size)} - set( - vocab_clone.values()) + indices_to_fill_in = ( + set(range(self.vocab_size)) - set(vocab_clone.values()) + ) # Add enough indices to make get_vocab() the right length for index_to_add in indices_to_fill_in: @@ -240,7 +258,8 @@ def _tokenize(self, text: str) -> List[str]: """Returns a tokenized string.""" if not isinstance(text, str): raise ValueError( - f'Expected a string input to _tokenize but got {type(text)}.') + f'Expected a string input to _tokenize but got {type(text)}.', + ) tokens = [ self.decoder[t] @@ -264,13 +283,14 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: """Converts a sequence of tokens (string) in a single string.""" text = ''.join(tokens) text = bytearray([self.byte_decoder[c] for c in text - ]).decode('utf-8', errors=self.errors) + ],).decode('utf-8', errors=self.errors) return text def build_inputs_with_special_tokens( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + ) -> List[int]: bos_token_id = [self.bos_token_id] if self.add_bos_token else [] eos_token_id = [self.eos_token_id] if self.add_eos_token else [] @@ -282,10 +302,11 @@ def build_inputs_with_special_tokens( return output def get_special_tokens_mask( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False) -> List[int]: + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: """Retrieves sequence ids from a token list that has no special tokens. Function copied from @@ -309,29 +330,35 @@ def get_special_tokens_mask( return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, - already_has_special_tokens=True) + already_has_special_tokens=True, + ) bos_token_id = [1] if self.add_bos_token else [] eos_token_id = [1] if self.add_eos_token else [] if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + - bos_token_id + ([0] * len(token_ids_1)) + eos_token_id) + return ( + bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id + ) def create_token_type_ids_from_sequences( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + ) -> List[int]: sep = [self.sep_token_id] if token_ids_1 is None: return len(token_ids_0 + sep) * [0] return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] - def save_vocabulary(self, - save_directory: str, - filename_prefix: Optional[str] = None) -> Tuple[str]: + def save_vocabulary( + self, + save_directory: str, + filename_prefix: Optional[str] = None, + ) -> Tuple[str]: # ignore the below type to keep the original signature # we are knowingly breaking the signature here, although not 100% certain diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 2c3d7c9bc3..dd43efcdd7 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -1,44 +1,63 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from llmfoundry.utils.builders import (build_algorithm, build_callback, - build_composer_model, build_evaluators, - build_icl_data_and_gauntlet, - build_icl_evaluators, build_logger, - build_metric, build_optimizer, - build_scheduler, build_tokenizer) +from llmfoundry.utils.builders import ( + build_algorithm, + build_callback, + build_composer_model, + build_evaluators, + build_icl_data_and_gauntlet, + build_icl_evaluators, + build_logger, + build_metric, + build_optimizer, + build_scheduler, + build_tokenizer, +) from llmfoundry.utils.checkpoint_conversion_helpers import ( - convert_and_save_ft_weights, get_hf_tokenizer_from_composer_state_dict, - load_tokenizer) -from llmfoundry.utils.config_utils import (calculate_batch_size_info, - log_config, pop_config, - process_init_device, - update_batch_size_info) -from llmfoundry.utils.data_prep_utils import (DownloadingIterable, - merge_shard_groups) + convert_and_save_ft_weights, + get_hf_tokenizer_from_composer_state_dict, + load_tokenizer, +) +from llmfoundry.utils.config_utils import ( + calculate_batch_size_info, + log_config, + pop_config, + process_init_device, + update_batch_size_info, +) +from llmfoundry.utils.data_prep_utils import ( + DownloadingIterable, + merge_shard_groups, +) from llmfoundry.utils.huggingface_hub_utils import \ edit_files_for_hf_compatibility from llmfoundry.utils.logging_utils import SpecificWarningFilter from llmfoundry.utils.model_download_utils import ( - download_from_hf_hub, download_from_http_fileserver, download_from_oras) - -# isort: off + download_from_hf_hub, + download_from_http_fileserver, + download_from_oras, +) from llmfoundry.utils.mosaicml_logger_utils import ( find_mosaicml_logger, log_eval_analytics, log_train_analytics, maybe_create_mosaicml_logger, ) -# isort: on from llmfoundry.utils.prompt_files import load_prompts, load_prompts_from_file -from llmfoundry.utils.registry_utils import (TypedRegistry, - construct_from_registry, - create_registry, import_file, - save_registry) -from llmfoundry.utils.warnings import (ExperimentalWarning, - VersionedDeprecationWarning, - experimental_class, - experimental_function) +from llmfoundry.utils.registry_utils import ( + TypedRegistry, + construct_from_registry, + create_registry, + import_file, + save_registry, +) +from llmfoundry.utils.warnings import ( + ExperimentalWarning, + VersionedDeprecationWarning, + experimental_class, + experimental_function, +) __all__ = [ 'build_algorithm', diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index bc7eda350d..1c9dbc54a3 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -8,8 +8,16 @@ import re import warnings from collections import OrderedDict -from typing import (Any, ContextManager, Dict, Iterable, List, Optional, Tuple, - Union) +from typing import ( + Any, + ContextManager, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import torch from composer.core import Algorithm, Callback, Evaluator @@ -98,8 +106,11 @@ def build_eval_loaders( is_multi_eval = False for eval_config in eval_configs: - eval_dataloader = build_dataloader(eval_config, tokenizer, - device_eval_batch_size) + eval_dataloader = build_dataloader( + eval_config, + tokenizer, + device_eval_batch_size, + ) eval_loader: Evaluator = Evaluator( label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, @@ -134,14 +145,15 @@ def build_icl_data_and_gauntlet( tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, - icl_subset_num_batches: Optional[int] = None + icl_subset_num_batches: Optional[int] = None, ) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]: icl_evaluators, logger_keys = build_icl_evaluators( icl_tasks_config, tokenizer, icl_seq_len, device_eval_batch_size, - icl_subset_num_batches=icl_subset_num_batches) + icl_subset_num_batches=icl_subset_num_batches, + ) eval_gauntlet_cb = None if eval_gauntlet_config is not None: if isinstance(eval_gauntlet_config, str): @@ -152,7 +164,7 @@ def build_icl_data_and_gauntlet( eval_gauntlet = eval_gauntlet_config else: raise ValueError( - f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}' + f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}', ) eval_gauntlet.logger_keys = logger_keys eval_gauntlet.benchmark_sizes = { @@ -192,7 +204,7 @@ def build_composer_model( post_validation_function=None, kwargs={ 'om_model_config': cfg, - 'tokenizer': tokenizer + 'tokenizer': tokenizer, }, ) @@ -207,7 +219,8 @@ def build_composer_model( if master_weights_dtype not in str_dtype_to_torch_dtype: raise ValueError( f'Invalid master_weights_dtype: {master_weights_dtype}. ' + - f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.') + f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.', + ) dtype = str_dtype_to_torch_dtype[master_weights_dtype] model = model.to(dtype=dtype) @@ -226,49 +239,61 @@ def build_callback( kwargs = {} if 'train_config' in kwargs: raise ValueError( - f'`train_config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.' + f'`train_config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.', ) kwargs['train_config'] = train_config registry_to_use = registry.callbacks_with_config - return construct_from_registry(name=name, - registry=registry_to_use, - partial_function=True, - pre_validation_function=Callback, - post_validation_function=None, - kwargs=kwargs) + return construct_from_registry( + name=name, + registry=registry_to_use, + partial_function=True, + pre_validation_function=Callback, + post_validation_function=None, + kwargs=kwargs, + ) -def build_logger(name: str, - kwargs: Optional[Dict[str, Any]] = None) -> LoggerDestination: +def build_logger( + name: str, + kwargs: Optional[Dict[str, Any]] = None, +) -> LoggerDestination: """Builds a logger from the registry.""" - return construct_from_registry(name=name, - registry=registry.loggers, - partial_function=True, - pre_validation_function=LoggerDestination, - post_validation_function=None, - kwargs=kwargs) + return construct_from_registry( + name=name, + registry=registry.loggers, + partial_function=True, + pre_validation_function=LoggerDestination, + post_validation_function=None, + kwargs=kwargs, + ) -def build_algorithm(name: str, - kwargs: Optional[Dict[str, Any]] = None) -> Algorithm: +def build_algorithm( + name: str, + kwargs: Optional[Dict[str, Any]] = None, +) -> Algorithm: """Builds an algorithm from the registry.""" - return construct_from_registry(name=name, - registry=registry.algorithms, - partial_function=True, - pre_validation_function=Algorithm, - post_validation_function=None, - kwargs=kwargs) + return construct_from_registry( + name=name, + registry=registry.algorithms, + partial_function=True, + pre_validation_function=Algorithm, + post_validation_function=None, + kwargs=kwargs, + ) def build_metric(name: str, kwargs: Optional[Dict[str, Any]] = None) -> Metric: """Builds a metric from the registry.""" - return construct_from_registry(name=name, - registry=registry.metrics, - partial_function=True, - pre_validation_function=Metric, - post_validation_function=None, - kwargs=kwargs) + return construct_from_registry( + name=name, + registry=registry.metrics, + partial_function=True, + pre_validation_function=Metric, + post_validation_function=None, + kwargs=kwargs, + ) def _extract_param_groups( @@ -375,9 +400,10 @@ def _extract_param_groups( def build_optimizer( - model: torch.nn.Module, - name: str, - optimizer_config: Optional[Dict[str, Any]] = None) -> Optimizer: + model: torch.nn.Module, + name: str, + optimizer_config: Optional[Dict[str, Any]] = None, +) -> Optimizer: params = _extract_param_groups(model, optimizer_config) kwargs = optimizer_config @@ -387,21 +413,24 @@ def build_optimizer( if 'params' in kwargs: raise ValueError( 'The `params` will be automatically extracted from the model and ' + - 'optimizer config. Please remove it from the optimizer config kwargs.' + 'optimizer config. Please remove it from the optimizer config kwargs.', ) kwargs['params'] = params - return construct_from_registry(name=name, - registry=registry.optimizers, - partial_function=True, - pre_validation_function=Optimizer, - post_validation_function=None, - kwargs=kwargs) + return construct_from_registry( + name=name, + registry=registry.optimizers, + partial_function=True, + pre_validation_function=Optimizer, + post_validation_function=None, + kwargs=kwargs, + ) def build_scheduler( - name: str, - scheduler_config: Optional[Dict[str, Any]] = None) -> ComposerScheduler: + name: str, + scheduler_config: Optional[Dict[str, Any]] = None, +) -> ComposerScheduler: return construct_from_registry( name=name, registry=registry.schedulers, @@ -413,8 +442,9 @@ def build_scheduler( def build_tokenizer( - tokenizer_name: str, - tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase: + tokenizer_name: str, + tokenizer_kwargs: Dict[str, Any], +) -> PreTrainedTokenizerBase: os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' os.environ['TOKENIZERS_PARALLELISM'] = 'false' @@ -429,8 +459,10 @@ def build_tokenizer( if tokenizer_name.startswith('tiktoken'): tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) else: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, - **tokenizer_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + **tokenizer_kwargs, + ) # HuggingFace does not respect the model_max_length kwarg, and overrides it with # min(kwargs['model_max_length'], original_config['model_max_length']), so we @@ -442,7 +474,8 @@ def build_tokenizer( if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None: raise ValueError( - f'The tokenizer {tokenizer_name} must have an eos_token.') + f'The tokenizer {tokenizer_name} must have an eos_token.', + ) if dist.is_available() and dist.is_initialized( ) and dist.get_world_size() > 1: @@ -492,26 +525,28 @@ def _validate_cfg(icl_cfg: DictConfig): icl_cfg.metric_names = ['InContextLearningLMAccuracy'] elif icl_cfg.icl_task_type == 'multiple_choice': icl_cfg.metric_names = [ - 'InContextLearningMultipleChoiceAccuracy' + 'InContextLearningMultipleChoiceAccuracy', ] elif icl_cfg.icl_task_type == 'schema': icl_cfg.metric_names = [ - 'InContextLearningMultipleChoiceAccuracy' + 'InContextLearningMultipleChoiceAccuracy', ] elif icl_cfg.icl_task_type == 'generation_task_with_answers' or icl_cfg.icl_task_type == 'question_answering': if icl_cfg.icl_task_type == 'question_answering': warnings.warn( VersionedDeprecationWarning( "ICL task type 'question_answering' is now deprecated. Use identifier 'generation_task_with_answers'", - 'v0.9.0')) + 'v0.9.0', + ), + ) icl_cfg.metric_names = [ - 'InContextLearningGenerationExactMatchAccuracy' + 'InContextLearningGenerationExactMatchAccuracy', ] elif icl_cfg.icl_task_type == 'code_evaluation': icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy'] else: raise ValueError( - f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.' + f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.', ) if 'prompt_string' not in icl_cfg: @@ -556,13 +591,18 @@ def _validate_cfg(icl_cfg: DictConfig): hf_parsing_map = icl_cfg.get('hf_parsing_map', {}) hf_loading_vars = icl_cfg.get('hf_loading_vars', {}) - early_stopping_criteria = icl_cfg.get('early_stopping_criteria', - None) + early_stopping_criteria = icl_cfg.get( + 'early_stopping_criteria', + None, + ) if isinstance(early_stopping_criteria, ListConfig): early_stopping_criteria = om.to_container( - early_stopping_criteria) + early_stopping_criteria, + ) assert early_stopping_criteria is None or isinstance( - early_stopping_criteria, list) + early_stopping_criteria, + list, + ) dataloaders = get_icl_task_dataloader( icl_cfg.icl_task_type, icl_cfg.dataset_uri, @@ -585,26 +625,34 @@ def _validate_cfg(icl_cfg: DictConfig): cot_delimiter=icl_cfg.get('cot_delimiter', ''), generation_kwargs=icl_cfg.get('generation_kwargs', {}), early_stopping_criteria=early_stopping_criteria, - do_normalization=icl_cfg.get('do_normalization', True)) + do_normalization=icl_cfg.get('do_normalization', True), + ) if hasattr( - icl_cfg, - 'has_categories') and icl_cfg.has_categories and isinstance( - dataloaders, dict): + icl_cfg, + 'has_categories', + ) and icl_cfg.has_categories and isinstance(dataloaders, dict): for category in dataloaders.keys(): logger_keys.extend([ f'metrics/{label}/{category}/{m}' for m in metric_names ]) evaluators.append( - Evaluator(label=f'{label}/{category}', - dataloader=dataloaders[category], - metric_names=metric_names),) + Evaluator( + label=f'{label}/{category}', + dataloader=dataloaders[category], + metric_names=metric_names, + ), + ) else: - logger_keys.extend( - [f'metrics/{label}/{m}' for m in metric_names]) + logger_keys.extend([ + f'metrics/{label}/{m}' for m in metric_names + ]) evaluators.append( - Evaluator(label=label, - dataloader=dataloaders, - metric_names=metric_names, - subset_num_batches=icl_subset_num_batches)) + Evaluator( + label=label, + dataloader=dataloaders, + metric_names=metric_names, + subset_num_batches=icl_subset_num_batches, + ), + ) return evaluators, logger_keys diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index 47a9750399..905afd6edb 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -18,8 +18,11 @@ from typing import Any, Dict, Optional, Tuple, Union import numpy as np -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) log = logging.getLogger(__name__) @@ -47,12 +50,12 @@ def get_hf_tokenizer_from_composer_state_dict( ) -> Optional[PreTrainedTokenizer]: if 'state' not in state_dict: raise RuntimeError( - 'Unexpected composer state dictionary. Did you pass in a full composer checkpoint?' + 'Unexpected composer state dictionary. Did you pass in a full composer checkpoint?', ) if 'integrations' not in state_dict[ - 'state'] or 'huggingface' not in state_dict['state']['integrations']: + 'state'] or 'huggingface' not in state_dict['state']['integrations']: raise RuntimeError( - 'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!' + 'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!', ) hf_tokenizer_state = state_dict['state']['integrations']['huggingface'][ 'tokenizer'] @@ -60,9 +63,12 @@ def get_hf_tokenizer_from_composer_state_dict( if hf_tokenizer_state != {}: if tokenizer_save_dir is None: unique_suffix = ''.join( - random.choices(string.ascii_letters + string.digits, k=6)) + random.choices(string.ascii_letters + string.digits, k=6), + ) tokenizer_save_dir = os.path.join( - os.getcwd(), f'tokenizer-save-dir-{unique_suffix}') + os.getcwd(), + f'tokenizer-save-dir-{unique_suffix}', + ) os.makedirs(tokenizer_save_dir, exist_ok=True) for filename, saved_content in hf_tokenizer_state.items(): @@ -91,7 +97,7 @@ def get_hf_tokenizer_from_composer_state_dict( import sentencepiece as spm except ImportError as e: raise ImportError( - 'Your tokenizer uses `sentencepiece`. Please install `sentencepiece` to load it.' + 'Your tokenizer uses `sentencepiece`. Please install `sentencepiece` to load it.', ) from e s = spm.SentencePieceProcessor() @@ -99,8 +105,10 @@ def get_hf_tokenizer_from_composer_state_dict( with open(tokenizer_file_path, 'wb') as _tmp_file: _tmp_file.write(s.serialized_model_proto()) - hf_tokenizer = load_tokenizer(tokenizer_save_dir, - trust_remote_code=trust_remote_code) + hf_tokenizer = load_tokenizer( + tokenizer_save_dir, + trust_remote_code=trust_remote_code, + ) # remove 'name_or_path' hf_tokenizer.name_or_path = '' @@ -110,22 +118,29 @@ def get_hf_tokenizer_from_composer_state_dict( def load_tokenizer( - tokenizer_save_dir: str, trust_remote_code: bool + tokenizer_save_dir: str, + trust_remote_code: bool, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: try: return AutoTokenizer.from_pretrained( - tokenizer_save_dir, trust_remote_code=trust_remote_code) + tokenizer_save_dir, + trust_remote_code=trust_remote_code, + ) except ValueError as e: raise ValueError( f'Got error while loading tokenizer with trust_remote_code={trust_remote_code}: {e}. ' + 'If accessing a tokenizer defined outside of the transformers module,' - + ' please use --trust_remote_code.') + + ' please use --trust_remote_code.', + ) -def _write_zero_bias(weight_name: str, weight_file_path: str, - bias_shape: Union[Tuple[int, ...], - int], np_data_type: np.dtype) -> None: +def _write_zero_bias( + weight_name: str, + weight_file_path: str, + bias_shape: Union[Tuple[int, ...], int], + np_data_type: np.dtype, +) -> None: """Write zeros for bias when converting MPT to FasterTransformer weights. MPT model might not have bias while FT expects bias. @@ -138,7 +153,7 @@ def _write_zero_bias(weight_name: str, weight_file_path: str, """ if 'weight' not in weight_file_path: raise RuntimeError( - f'Cannot write zero bias for {weight_name}. Input is not a weight tensor' + f'Cannot write zero bias for {weight_name}. Input is not a weight tensor', ) log.debug(f'zero bias for weight: {weight_name}') bias_file_path = weight_file_path.replace('.weight', '.bias') @@ -146,10 +161,14 @@ def _write_zero_bias(weight_name: str, weight_file_path: str, bias.tofile(bias_file_path) -def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, - tensor_name: str, config: Dict[str, Any], - data: np.ndarray, - np_weight_data_type: np.dtype) -> None: +def _convert_weight_to_ft_each( + save_dir: str, + infer_gpu_num: int, + tensor_name: str, + config: Dict[str, Any], + data: np.ndarray, + np_weight_data_type: np.dtype, +) -> None: """Convert each MPT weight to a FasterTransformer compatible format. Args: @@ -170,14 +189,18 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, save_path = os.path.join(save_dir, f'model.{tensor_name}.bin') data.tofile(save_path) if 'weight' in tensor_name and config['no_bias']: - _write_zero_bias(tensor_name, save_path, data.shape[-1], - np_weight_data_type - ) # pyright: ignore [reportGeneralTypeIssues] + _write_zero_bias( + tensor_name, + save_path, + data.shape[-1], + np_weight_data_type, + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('attention.dense.weight') != -1: assert data.shape == ( config['d_model'], - config['d_model']), f'unexpected dim for {tensor_name}' + config['d_model'], + ), f'unexpected dim for {tensor_name}' # nn.Linear weights are transposed data = data.T split_vals = np.split(data, infer_gpu_num, axis=0) @@ -185,16 +208,22 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin') split_vals[j].tofile(save_path) if config['no_bias']: - fake_weight_path = os.path.join(save_dir, - f'model.{tensor_name}.bin') - _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], - np_weight_data_type - ) # pyright: ignore [reportGeneralTypeIssues] + fake_weight_path = os.path.join( + save_dir, + f'model.{tensor_name}.bin', + ) + _write_zero_bias( + tensor_name, + fake_weight_path, + data.shape[-1], + np_weight_data_type, + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1: assert data.shape == ( - config['d_model'], config['expansion_ratio'] * - config['d_model']), f'unexpected dim for {tensor_name}' + config['d_model'], + config['expansion_ratio'] * config['d_model'], + ), f'unexpected dim for {tensor_name}' # nn.Linear weights are transposed data = data.T split_vals = np.split(data, infer_gpu_num, axis=0) @@ -202,16 +231,22 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin') split_vals[j].tofile(save_path) if config['no_bias']: - fake_weight_path = os.path.join(save_dir, - f'model.{tensor_name}.bin') - _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], - np_weight_data_type - ) # pyright: ignore [reportGeneralTypeIssues] + fake_weight_path = os.path.join( + save_dir, + f'model.{tensor_name}.bin', + ) + _write_zero_bias( + tensor_name, + fake_weight_path, + data.shape[-1], + np_weight_data_type, + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1: assert data.shape == ( config['expansion_ratio'] * config['d_model'], - config['d_model']), f'unexpected dim for {tensor_name}' + config['d_model'], + ), f'unexpected dim for {tensor_name}' # nn.Linear weights are transposed data = data.T @@ -220,14 +255,17 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin') split_vals[j].tofile(save_path) if config['no_bias']: - _write_zero_bias(tensor_name, save_path, - split_vals[j].shape[-1], np_weight_data_type - ) # pyright: ignore [reportGeneralTypeIssues] + _write_zero_bias( + tensor_name, + save_path, + split_vals[j].shape[-1], + np_weight_data_type, + ) # pyright: ignore [reportGeneralTypeIssues] elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1: assert data.shape == ( - config['expansion_ratio'] * - config['d_model'],), f'unexpected dim for {tensor_name}' + config['expansion_ratio'] * config['d_model'], + ), f'unexpected dim for {tensor_name}' split_vals = np.split(data, infer_gpu_num, axis=-1) for j in range(infer_gpu_num): save_path = os.path.join(save_dir + f'model.{tensor_name}.{j}.bin') @@ -235,7 +273,8 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, elif tensor_name.find('attention.query_key_value.bias') != -1: assert data.shape == ( - 3 * config['d_model'],), f'unexpected dim for {tensor_name}' + 3 * config['d_model'], + ), f'unexpected dim for {tensor_name}' data = data.reshape(3, config['d_model']) @@ -248,7 +287,8 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, elif tensor_name.find('attention.query_key_value.weight') != -1: assert data.shape == ( 3 * config['d_model'], - config['d_model']), f'unexpected dim for {tensor_name}' + config['d_model'], + ), f'unexpected dim for {tensor_name}' # nn.Linear weights are transposed data = data.T @@ -259,20 +299,24 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin') split_vals[j].tofile(save_path) if config['no_bias']: - _write_zero_bias(tensor_name, save_path, - (3, split_vals[j].shape[-1]), - np_weight_data_type - ) # pyright: ignore [reportGeneralTypeIssues] + _write_zero_bias( + tensor_name, + save_path, + (3, split_vals[j].shape[-1]), + np_weight_data_type, + ) # pyright: ignore [reportGeneralTypeIssues] else: raise RuntimeError(f'Tensor with name {tensor_name} is not handled') -def convert_and_save_ft_weights(named_params: dict, - config: dict, - infer_gpu_num: int = 1, - weight_data_type: str = 'fp32', - save_dir: str = '') -> None: +def convert_and_save_ft_weights( + named_params: dict, + config: dict, + infer_gpu_num: int = 1, + weight_data_type: str = 'fp32', + save_dir: str = '', +) -> None: """Convert a Composer MPT checkpoint to a FasterTransformer format. Args: @@ -311,44 +355,52 @@ def convert_and_save_ft_weights(named_params: dict, if name == 'transformer.wpe.weight': assert data.shape == ( config['max_seq_len'], - config['d_model']), f'unexpected dim for {name}' + config['d_model'], + ), f'unexpected dim for {name}' data.tofile(os.path.join(save_dir, 'model.wpe.bin')) elif name == 'transformer.wte.weight': assert data.shape == ( config['vocab_size'], - config['d_model']), f'unexpected dim for {name}' + config['d_model'], + ), f'unexpected dim for {name}' data.tofile(os.path.join(save_dir, 'model.wte.bin')) elif name == 'transformer.norm_f.bias': assert data.shape == ( - config['d_model'],), f'unexpected dim for {name}' - data.tofile(os.path.join(save_dir, - 'model.final_layernorm.bias.bin')) + config['d_model'], + ), f'unexpected dim for {name}' + data.tofile( + os.path.join(save_dir, 'model.final_layernorm.bias.bin'), + ) elif name == 'transformer.norm_f.weight': assert data.shape == ( - config['d_model'],), f'unexpected dim for {name}' - save_path = os.path.join(save_dir, - 'model.final_layernorm.weight.bin') + config['d_model'], + ), f'unexpected dim for {name}' + save_path = os.path.join( + save_dir, + 'model.final_layernorm.weight.bin', + ) data.tofile(save_path) if config['no_bias']: _write_zero_bias( name, save_path, data.shape[-1], - np_weight_data_type # pyright: ignore [reportGeneralTypeIssues] + np_weight_data_type, # pyright: ignore [reportGeneralTypeIssues] ) elif name == 'transformer.lm_head.weight': data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin')) else: for mpt_pattern, ft_pattern in param_remapping.items(): if name.find(mpt_pattern) != -1: - new_name = name.replace('transformer.blocks.', - 'layers.').replace( - mpt_pattern, ft_pattern) + new_name = name.replace( + 'transformer.blocks.', + 'layers.', + ).replace(mpt_pattern, ft_pattern) _convert_weight_to_ft_each( save_dir, infer_gpu_num, new_name, config, data, - np_weight_data_type # pyright: ignore [reportGeneralTypeIssues] + np_weight_data_type, # pyright: ignore [reportGeneralTypeIssues] ) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 0235462a7f..0b05214565 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -27,11 +27,13 @@ ] -def pop_config(cfg: DictConfig, - key: str, - must_exist: bool = True, - default_value: Any = None, - convert: bool = False) -> Any: +def pop_config( + cfg: DictConfig, + key: str, + must_exist: bool = True, + default_value: Any = None, + convert: bool = False, +) -> Any: """Pop a value from the main config file and return it. If the key does not exist, return the default_value or raise a RuntimeError @@ -40,32 +42,34 @@ def pop_config(cfg: DictConfig, """ value = cfg.pop(key, None) if value is not None and convert: - if not isinstance(value, DictConfig) and not isinstance( - value, ListConfig): + if not isinstance(value, + DictConfig) and not isinstance(value, ListConfig): raise ValueError( f'The key {key} has a value of type {type(value)} that cannot be \ - converted to a dict or list. Please check your yaml.' + converted to a dict or list. Please check your yaml.', ) return om.to_container(value) elif value is not None: return value elif must_exist: raise NameError( - f'The {key} parameter is missing and must exist for execution. Please check your yaml.' + f'The {key} parameter is missing and must exist for execution. Please check your yaml.', ) else: return default_value def calculate_batch_size_info( - global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']] + global_batch_size: int, + device_microbatch_size: Union[int, Literal['auto']], ) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]: if global_batch_size % dist.get_world_size() != 0: raise ValueError( f'Global batch size {global_batch_size} is not divisible by {dist.get_world_size()} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' - + f'to be divisible by world size, {dist.get_world_size()}.') + + f'to be divisible by world size, {dist.get_world_size()}.', + ) device_batch_size = global_batch_size // dist.get_world_size() if device_microbatch_size == 'auto': device_grad_accum = 'auto' @@ -73,11 +77,12 @@ def calculate_batch_size_info( if device_microbatch_size > device_batch_size: log.warn( f'device_microbatch_size > device_batch_size, ' + - f'will be reduced from {device_microbatch_size} -> {device_batch_size}.' + f'will be reduced from {device_microbatch_size} -> {device_batch_size}.', ) device_microbatch_size = device_batch_size - device_grad_accum = math.ceil(device_batch_size / - device_microbatch_size) + device_grad_accum = math.ceil( + device_batch_size / device_microbatch_size, + ) else: raise ValueError(f'Not sure how to parse {device_microbatch_size=}') @@ -87,7 +92,9 @@ def calculate_batch_size_info( # Coming soon: this conversion math will be done inside Composer Trainer def update_batch_size_info(cfg: DictConfig) -> DictConfig: device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( - cfg.global_train_batch_size, cfg.device_train_microbatch_size) + cfg.global_train_batch_size, + cfg.device_train_microbatch_size, + ) cfg.n_gpus = dist.get_world_size() cfg.device_train_batch_size = device_train_batch_size cfg.device_train_microbatch_size = device_train_microbatch_size @@ -120,12 +127,14 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): if fsdp_config is None: raise NotImplementedError( 'Using init_device `mixed` is only supported with FSDP. ' + - 'Please add a FSDP config.') + 'Please add a FSDP config.', + ) # Always set `sync_module_states` to True for mixed initialization if not fsdp_config.get('sync_module_states', False): warnings.warn(( 'Setting `sync_module_states = True` for FSDP. This is required ' - 'when using mixed initialization.')) + 'when using mixed initialization.' + )) fsdp_config['sync_module_states'] = True # Set defaults for mixed initialization @@ -134,19 +143,27 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set ffn_config.device_mesh to fsdp_config.device_mesh if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ - 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: # Raise ValueError if not using device mesh with MoE expert parallelism if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( - 'moe_world_size', 1) > 1: + 'moe_world_size', + 1, + ) > 1: raise ValueError( - 'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.' + 'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.', ) model_cfg.ffn_config.device_mesh = fsdp_config['device_mesh'] # No mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') - small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', - 'amp_bf16') + small_dtypes = ( + 'bf16', + 'fp16', + 'float16', + 'bfloat16', + 'amp_fp16', + 'amp_bf16', + ) if fsdp_config and master_dtype in small_dtypes: reduce_dtype = None buffer_dtype = None @@ -203,29 +220,43 @@ def _parse_source_dataset(cfg: DictConfig) -> List[Tuple[str, str, str]]: train_dataset = cfg.get('train_loader', {}).get('dataset', {}) train_split = train_dataset.get('split', None) train_source_path = cfg.get('source_dataset_train', None) - _process_data_source(train_source_path, train_dataset, train_split, 'train', - data_paths) + _process_data_source( + train_source_path, + train_dataset, + train_split, + 'train', + data_paths, + ) # Handle eval_loader which might be a list or a single dictionary eval_data_loaders = cfg.get('eval_loader', {}) if not isinstance(eval_data_loaders, ListConfig): - eval_data_loaders = [eval_data_loaders - ] # Normalize to list if it's a single dictionary + eval_data_loaders = [ + eval_data_loaders, + ] # Normalize to list if it's a single dictionary for eval_data_loader in eval_data_loaders: eval_dataset = eval_data_loader.get('dataset', {}) eval_split = eval_dataset.get('split', None) eval_source_path = cfg.get('source_dataset_eval', None) - _process_data_source(eval_source_path, eval_dataset, eval_split, 'eval', - data_paths) + _process_data_source( + eval_source_path, + eval_dataset, + eval_split, + 'eval', + data_paths, + ) return data_paths -def _process_data_source(source_dataset_path: Optional[str], - dataset: Dict[str, str], cfg_split: Optional[str], - true_split: str, data_paths: List[Tuple[str, str, - str]]): +def _process_data_source( + source_dataset_path: Optional[str], + dataset: Dict[str, str], + cfg_split: Optional[str], + true_split: str, + data_paths: List[Tuple[str, str, str]], +): """Add a data source by mutating data_paths. Given various dataset attributes, attempt to determine what type of dataset is being added, and parse @@ -244,7 +275,8 @@ def _process_data_source(source_dataset_path: Optional[str], # Check for UC volume elif source_dataset_path and source_dataset_path.startswith('dbfs:'): data_paths.append( - ('uc_volume', source_dataset_path[len('dbfs:'):], true_split)) + ('uc_volume', source_dataset_path[len('dbfs:'):], true_split), + ) # Check for HF path elif 'hf_name' in dataset: hf_path = dataset['hf_name'] @@ -262,7 +294,9 @@ def _process_data_source(source_dataset_path: Optional[str], backend, _, _ = parse_uri(remote_path) if backend: remote_path = os.path.join( - remote_path, f'{cfg_split}/') if cfg_split else remote_path + remote_path, + f'{cfg_split}/', + ) if cfg_split else remote_path data_paths.append((backend, remote_path, true_split)) else: data_paths.append(('local', remote_path, true_split)) @@ -304,8 +338,10 @@ def _log_dataset_uri(cfg: DictConfig) -> None: source = source_class(url=path) else: log.info( - f'{dataset_type} unknown, defaulting to http dataset source') + f'{dataset_type} unknown, defaulting to http dataset source', + ) source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path) mlflow.log_input( - mlflow.data.meta_dataset.MetaDataset(source, name=split)) + mlflow.data.meta_dataset.MetaDataset(source, name=split), + ) diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py index 058e73b393..b5b606a57f 100644 --- a/llmfoundry/utils/data_prep_utils.py +++ b/llmfoundry/utils/data_prep_utils.py @@ -106,11 +106,15 @@ def __iter__(self): # Download objects if remote path. if self.object_store is not None: - output_filename = os.path.join(self.output_folder, - object_name.strip('/')) - self.object_store.download_object(object_name=object_name, - filename=output_filename, - overwrite=True) + output_filename = os.path.join( + self.output_folder, + object_name.strip('/'), + ) + self.object_store.download_object( + object_name=object_name, + filename=output_filename, + overwrite=True, + ) with open(output_filename) as _txt_file: txt = _txt_file.read() diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index ba34b29be3..8e9e46a1cf 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -49,9 +49,15 @@ def __init__(self) -> None: class NotEnoughDatasetSamplesError(ValueError): """Error thrown when there is not enough data to train a model.""" - def __init__(self, dataset_name: str, split: str, - dataloader_batch_size: int, world_size: int, - full_dataset_size: int, minimum_dataset_size: int) -> None: + def __init__( + self, + dataset_name: str, + split: str, + dataloader_batch_size: int, + world_size: int, + full_dataset_size: int, + minimum_dataset_size: int, + ) -> None: self.dataset_name = dataset_name self.split = split self.dataloader_batch_size = dataloader_batch_size @@ -64,7 +70,8 @@ def __init__(self, dataset_name: str, split: str, f'is {minimum_dataset_size} because you are running on {world_size} gpus and ' + f'your per device batch size is {dataloader_batch_size}. Please increase the number ' - + f'of samples in your dataset to at least {minimum_dataset_size}.') + + f'of samples in your dataset to at least {minimum_dataset_size}.' + ) super().__init__(message) @@ -172,7 +179,8 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: self.valid_extensions = valid_extensions message = ( f'safe_load is set to True. No data files with safe extensions {valid_extensions} ' - + f'found for dataset at local path {dataset_name}.') + + f'found for dataset at local path {dataset_name}.' + ) super().__init__(message) diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 3903a9bed3..b4eec89cbd 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -22,7 +22,9 @@ def visit(self, node: ast.AST) -> Optional[ast.AST]: def convert_to_relative_import( - module_name: str, original_parent_module_name: Optional[str]) -> str: + module_name: str, + original_parent_module_name: Optional[str], +) -> str: parts = module_name.split('.') if parts[-1] == original_parent_module_name: return '.' @@ -88,28 +90,38 @@ def process_file( for node in ast.walk(tree): # Remove any imports matching the remove_imports_prefix if isinstance( - node, - ast.ImportFrom) and node.module is not None and _remove_import( - node, remove_imports_prefix): + node, + ast.ImportFrom, + ) and node.module is not None and _remove_import( + node, + remove_imports_prefix, + ): nodes_to_remove.append(node) # Convert any (remaining) imports matching the flatten_imports_prefix # to relative imports - elif (isinstance(node, ast.ImportFrom) and node.module is not None and - _flatten_import(node, flatten_imports_prefix)): + elif ( + isinstance(node, ast.ImportFrom) and node.module is not None and + _flatten_import(node, flatten_imports_prefix) + ): module_path = find_module_file(node.module) - node.module = convert_to_relative_import(node.module, - parent_module_name) + node.module = convert_to_relative_import( + node.module, + parent_module_name, + ) # Recursively process any llmfoundry files new_files_to_process.append(module_path) # Remove the Composer* class - elif (isinstance(node, ast.ClassDef) and - node.name.startswith('Composer')): + elif ( + isinstance(node, ast.ClassDef) and node.name.startswith('Composer') + ): nodes_to_remove.append(node) # Remove the __all__ declaration in any __init__.py files, whose # enclosing module will be converted to a single file of the same name - elif (isinstance(node, ast.Assign) and len(node.targets) == 1 and - isinstance(node.targets[0], ast.Name) and - node.targets[0].id == '__all__'): + elif ( + isinstance(node, ast.Assign) and len(node.targets) == 1 and + isinstance(node.targets[0], ast.Name) and + node.targets[0].id == '__all__' + ): nodes_to_remove.append(node) transformer = DeleteSpecificNodes(nodes_to_remove) @@ -130,10 +142,13 @@ def process_file( def edit_files_for_hf_compatibility( folder: str, flatten_imports_prefix: Sequence[str] = ('llmfoundry',), - remove_imports_prefix: Sequence[str] = ('composer', 'omegaconf', - 'llmfoundry.metrics', - 'llmfoundry.eval', - 'llmfoundry.utils.builders') + remove_imports_prefix: Sequence[str] = ( + 'composer', + 'omegaconf', + 'llmfoundry.metrics', + 'llmfoundry.eval', + 'llmfoundry.utils.builders', + ), ) -> None: """Edit files to be compatible with Hugging Face Hub. diff --git a/llmfoundry/utils/logging_utils.py b/llmfoundry/utils/logging_utils.py index f6c930beab..c13b87b701 100644 --- a/llmfoundry/utils/logging_utils.py +++ b/llmfoundry/utils/logging_utils.py @@ -5,8 +5,10 @@ import os from composer.loggers import MosaicMLLogger -from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR, - MOSAICML_PLATFORM_ENV_VAR) +from composer.loggers.mosaicml_logger import ( + MOSAICML_ACCESS_TOKEN_ENV_VAR, + MOSAICML_PLATFORM_ENV_VAR, +) __all__ = [ 'SpecificWarningFilter', diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index 3707da3883..c11a47929f 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -49,10 +49,13 @@ ] -@tenacity.retry(retry=tenacity.retry_if_not_exception_type( - (ValueError, hf_hub.utils.RepositoryNotFoundError)), - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(min=1, max=10)) +@tenacity.retry( + retry=tenacity.retry_if_not_exception_type( + (ValueError, hf_hub.utils.RepositoryNotFoundError), + ), + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(min=1, max=10), +) def download_from_hf_hub( model: str, save_dir: str, @@ -83,20 +86,23 @@ def download_from_hf_hub( # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer. ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS) - safetensors_available = (SAFE_WEIGHTS_NAME in repo_files or - SAFE_WEIGHTS_INDEX_NAME in repo_files) - pytorch_available = (PYTORCH_WEIGHTS_NAME in repo_files or - PYTORCH_WEIGHTS_INDEX_NAME in repo_files) + safetensors_available = ( + SAFE_WEIGHTS_NAME in repo_files or SAFE_WEIGHTS_INDEX_NAME in repo_files + ) + pytorch_available = ( + PYTORCH_WEIGHTS_NAME in repo_files or + PYTORCH_WEIGHTS_INDEX_NAME in repo_files + ) if safetensors_available and pytorch_available: if prefer_safetensors: log.info( - 'Safetensors available and preferred. Excluding pytorch weights.' + 'Safetensors available and preferred. Excluding pytorch weights.', ) ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN) else: log.info( - 'Pytorch available and preferred. Excluding safetensors weights.' + 'Pytorch available and preferred. Excluding safetensors weights.', ) ignore_patterns.append(SAFE_WEIGHTS_PATTERN) elif safetensors_available: @@ -106,21 +112,23 @@ def download_from_hf_hub( else: raise ValueError( f'No supported model weights found in repo {model}.' + - ' Please make sure the repo contains either safetensors or pytorch weights.' + ' Please make sure the repo contains either safetensors or pytorch weights.', ) allow_patterns = TOKENIZER_FILES if tokenizer_only else None download_start = time.time() - hf_hub.snapshot_download(model, - local_dir=save_dir, - local_dir_use_symlinks=False, - ignore_patterns=ignore_patterns, - allow_patterns=allow_patterns, - token=token) + hf_hub.snapshot_download( + model, + local_dir=save_dir, + local_dir_use_symlinks=False, + ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns, + token=token, + ) download_duration = time.time() - download_start log.info( - f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds' + f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds', ) @@ -168,15 +176,15 @@ def _recursive_download( if response.status_code == HTTPStatus.UNAUTHORIZED: raise PermissionError( - f'Not authorized to download file from {url}. Received status code {response.status_code}. ' + f'Not authorized to download file from {url}. Received status code {response.status_code}. ', ) elif response.status_code == HTTPStatus.NOT_FOUND: raise ValueError( - f'Could not find file at {url}. Received status code {response.status_code}' + f'Could not find file at {url}. Received status code {response.status_code}', ) elif response.status_code != HTTPStatus.OK: raise RuntimeError( - f'Could not download file from {url}. Received unexpected status code {response.status_code}' + f'Could not download file from {url}. Received unexpected status code {response.status_code}', ) # Assume that the URL points to a file if it does not end with a slash. @@ -197,17 +205,20 @@ def _recursive_download( child_links = _extract_links_from_html(response.content.decode()) print(child_links) for child_link in child_links: - _recursive_download(session, - base_url, - urljoin(path, child_link), - save_dir, - ignore_cert=ignore_cert) + _recursive_download( + session, + base_url, + urljoin(path, child_link), + save_dir, + ignore_cert=ignore_cert, + ) -@tenacity.retry(retry=tenacity.retry_if_not_exception_type( - (PermissionError, ValueError)), - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(min=1, max=10)) +@tenacity.retry( + retry=tenacity.retry_if_not_exception_type((PermissionError, ValueError)), + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(min=1, max=10), +) def download_from_http_fileserver( url: str, save_dir: str, @@ -228,19 +239,23 @@ def download_from_http_fileserver( if ignore_cert: warnings.simplefilter('ignore', category=InsecureRequestWarning) - _recursive_download(session, - url, - '', - save_dir, - ignore_cert=ignore_cert) + _recursive_download( + session, + url, + '', + save_dir, + ignore_cert=ignore_cert, + ) -def download_from_oras(model: str, - config_file: str, - credentials_dir: str, - save_dir: str, - tokenizer_only: bool = False, - concurrency: int = 10): +def download_from_oras( + model: str, + config_file: str, + credentials_dir: str, + save_dir: str, + tokenizer_only: bool = False, + concurrency: int = 10, +): """Download from an OCI-compliant registry using oras. Args: @@ -254,7 +269,7 @@ def download_from_oras(model: str, """ if shutil.which(ORAS_CLI) is None: raise Exception( - f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ' + f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ', ) def _read_secrets_file(secret_file_path: str,): @@ -263,12 +278,14 @@ def _read_secrets_file(secret_file_path: str,): return f.read().strip() except Exception as error: raise ValueError( - f'secrets file {secret_file_path} failed to be read') from error + f'secrets file {secret_file_path} failed to be read', + ) from error secrets = {} for secret in ['username', 'password', 'registry']: secrets[secret] = _read_secrets_file( - os.path.join(credentials_dir, secret)) + os.path.join(credentials_dir, secret), + ) with open(config_file, 'r', encoding='utf-8') as f: configs = yaml.safe_load(f.read()) @@ -277,8 +294,10 @@ def _read_secrets_file(secret_file_path: str,): path = configs[config_type][model] registry = secrets['registry'] - def get_oras_cmd(username: Optional[str] = None, - password: Optional[str] = None): + def get_oras_cmd( + username: Optional[str] = None, + password: Optional[str] = None, + ): cmd = [ ORAS_CLI, 'pull', @@ -298,11 +317,17 @@ def get_oras_cmd(username: Optional[str] = None, cmd_without_creds = get_oras_cmd() log.info(f'CMD for oras cli to run: {" ".join(cmd_without_creds)}') - cmd_to_run = get_oras_cmd(username=secrets['username'], - password=secrets['password']) + cmd_to_run = get_oras_cmd( + username=secrets['username'], + password=secrets['password'], + ) try: subprocess.run(cmd_to_run, check=True) except subprocess.CalledProcessError as e: # Intercept the error and replace the cmd, which may have sensitive info. - raise subprocess.CalledProcessError(e.returncode, cmd_without_creds, - e.output, e.stderr) + raise subprocess.CalledProcessError( + e.returncode, + cmd_without_creds, + e.output, + e.stderr, + ) diff --git a/llmfoundry/utils/mosaicml_logger_utils.py b/llmfoundry/utils/mosaicml_logger_utils.py index b4a40821ed..a65ebd9454 100644 --- a/llmfoundry/utils/mosaicml_logger_utils.py +++ b/llmfoundry/utils/mosaicml_logger_utils.py @@ -6,8 +6,10 @@ from composer.loggers import MosaicMLLogger from composer.loggers.logger_destination import LoggerDestination -from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR, - MOSAICML_PLATFORM_ENV_VAR) +from composer.loggers.mosaicml_logger import ( + MOSAICML_ACCESS_TOKEN_ENV_VAR, + MOSAICML_PLATFORM_ENV_VAR, +) from omegaconf import DictConfig, ListConfig __all__ = [ @@ -37,25 +39,31 @@ def maybe_create_mosaicml_logger() -> Optional[MosaicMLLogger]: def find_mosaicml_logger( - loggers: List[LoggerDestination]) -> Optional[MosaicMLLogger]: + loggers: List[LoggerDestination], +) -> Optional[MosaicMLLogger]: """Returns the first MosaicMLLogger from a list, and None otherwise.""" return next( (logger for logger in loggers if isinstance(logger, MosaicMLLogger)), - None) + None, + ) -def log_eval_analytics(mosaicml_logger: MosaicMLLogger, - model_configs: ListConfig, icl_tasks: Union[str, - ListConfig], - eval_gauntlet_config: Optional[Union[str, DictConfig]]): +def log_eval_analytics( + mosaicml_logger: MosaicMLLogger, + model_configs: ListConfig, + icl_tasks: Union[str, ListConfig], + eval_gauntlet_config: Optional[Union[str, DictConfig]], +): """Logs analytics for runs using the `eval.py` script.""" metrics: Dict[str, Any] = { 'llmfoundry/script': 'eval', } metrics['llmfoundry/gauntlet_configured'] = eval_gauntlet_config is not None - metrics['llmfoundry/icl_configured'] = isinstance(icl_tasks, - str) or len(icl_tasks) > 0 + metrics['llmfoundry/icl_configured'] = isinstance( + icl_tasks, + str, + ) or len(icl_tasks) > 0 metrics['llmfoundry/model_configs'] = [] for model_config in model_configs: @@ -67,21 +75,24 @@ def log_eval_analytics(mosaicml_logger: MosaicMLLogger, if len(model_config_data) > 0: metrics['llmfoundry/model_configs'].append( - json.dumps(model_config_data, sort_keys=True)) + json.dumps(model_config_data, sort_keys=True), + ) mosaicml_logger.log_metrics(metrics) mosaicml_logger._flush_metadata(force_flush=True) -def log_train_analytics(mosaicml_logger: MosaicMLLogger, - model_config: DictConfig, - train_loader_config: DictConfig, - eval_loader_config: Optional[Union[DictConfig, - ListConfig]], - callback_configs: Optional[DictConfig], - tokenizer_name: str, load_path: Optional[str], - icl_tasks_config: Optional[Union[ListConfig, str]], - eval_gauntlet: Optional[Union[DictConfig, str]]): +def log_train_analytics( + mosaicml_logger: MosaicMLLogger, + model_config: DictConfig, + train_loader_config: DictConfig, + eval_loader_config: Optional[Union[DictConfig, ListConfig]], + callback_configs: Optional[DictConfig], + tokenizer_name: str, + load_path: Optional[str], + icl_tasks_config: Optional[Union[ListConfig, str]], + eval_gauntlet: Optional[Union[DictConfig, str]], +): """Logs analytics for runs using the `train.py` script.""" train_loader_dataset = train_loader_config.get('dataset', {}) metrics: Dict[str, Any] = { @@ -96,12 +107,16 @@ def log_train_analytics(mosaicml_logger: MosaicMLLogger, ] metrics['llmfoundry/gauntlet_configured'] = eval_gauntlet is not None - metrics['llmfoundry/icl_configured'] = (icl_tasks_config is not None and ( - (isinstance(icl_tasks_config, str) or len(icl_tasks_config) > 0))) + metrics['llmfoundry/icl_configured'] = ( + icl_tasks_config is not None and + ((isinstance(icl_tasks_config, str) or len(icl_tasks_config) > 0)) + ) if train_loader_dataset.get('hf_name', None) is not None: metrics['llmfoundry/train_dataset_hf_name'] = train_loader_dataset.get( - 'hf_name', None) + 'hf_name', + None, + ) if train_loader_config.get('name') == 'finetuning': metrics['llmfoundry/train_task_type'] = 'INSTRUCTION_FINETUNE' elif train_loader_config.get('name') == 'text': @@ -124,11 +139,13 @@ def log_train_analytics(mosaicml_logger: MosaicMLLogger, eval_loader_info['name'] = loader_config.get('name') if eval_loader_dataset.get('hf_name', None) is not None: eval_loader_info['dataset_hf_name'] = eval_loader_dataset.get( - 'hf_name') + 'hf_name', + ) # Log as a key-sorted JSON string, so that we can easily parse it in Spark / SQL metrics['llmfoundry/eval_loaders'].append( - json.dumps(eval_loader_info, sort_keys=True)) + json.dumps(eval_loader_info, sort_keys=True), + ) model_config_data = {} for key in _MODEL_KEYS_TO_LOG: diff --git a/llmfoundry/utils/prompt_files.py b/llmfoundry/utils/prompt_files.py index 38e8439f11..64e5de70a1 100644 --- a/llmfoundry/utils/prompt_files.py +++ b/llmfoundry/utils/prompt_files.py @@ -34,8 +34,10 @@ def load_prompts(prompts: List[str], return prompt_strings -def load_prompts_from_file(prompt_path: str, - prompt_delimiter: Optional[str] = None) -> List[str]: +def load_prompts_from_file( + prompt_path: str, + prompt_delimiter: Optional[str] = None, +) -> List[str]: """Load a set of prompts from a text fie. Args: @@ -53,7 +55,8 @@ def load_prompts_from_file(prompt_path: str, prompt_file_path = os.path.expanduser(prompt_file_path) if not os.path.isfile(prompt_file_path): raise FileNotFoundError( - f'{prompt_file_path=} does not match any existing files.') + f'{prompt_file_path=} does not match any existing files.', + ) with open(prompt_file_path, 'r') as f: prompt_string = f.read() diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 1604a8a91f..3ea7cc58a7 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -8,8 +8,17 @@ from contextlib import contextmanager from pathlib import Path from types import ModuleType -from typing import (Any, Callable, Dict, Generic, Optional, Sequence, Type, - TypeVar, Union) +from typing import ( + Any, + Callable, + Dict, + Generic, + Optional, + Sequence, + Type, + TypeVar, + Union, +) import catalogue @@ -32,10 +41,12 @@ class TypedRegistry(catalogue.Registry, Generic[T]): descriptions. """ - def __init__(self, - namespace: Sequence[str], - entry_points: bool = False, - description: str = '') -> None: + def __init__( + self, + namespace: Sequence[str], + entry_points: bool = False, + description: str = '', + ) -> None: super().__init__(namespace, entry_points=entry_points) self.description = description @@ -46,10 +57,12 @@ def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]: def register(self, name: str, *, func: Optional[T] = None) -> T: return super().register(name, func=func) - def register_class(self, - name: str, - *, - func: Optional[TypeBoundT] = None) -> TypeBoundT: + def register_class( + self, + name: str, + *, + func: Optional[TypeBoundT] = None, + ) -> TypeBoundT: return super().register(name, func=func) def get(self, name: str) -> T: @@ -88,9 +101,11 @@ def create_registry( if catalogue.check_exists(*namespace): raise catalogue.RegistryError(f'Namespace already exists: {namespace}') - return TypedRegistry[generic_type](namespace, - entry_points=entry_points, - description=description) + return TypedRegistry[generic_type]( + namespace, + entry_points=entry_points, + description=description, + ) def construct_from_registry( @@ -128,26 +143,27 @@ def construct_from_registry( if isinstance(pre_validation_function, type): if not issubclass(registered_constructor, pre_validation_function): raise ValueError( - f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}' + f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}', ) elif isinstance(pre_validation_function, Callable): pre_validation_function(registered_constructor) else: raise ValueError( - f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}' + f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}', ) # If it is a class, or a builder function, construct the class with kwargs # If it is a function, create a partial with kwargs if isinstance( - registered_constructor, - type) or callable(registered_constructor) and not partial_function: + registered_constructor, + type, + ) or callable(registered_constructor) and not partial_function: constructed_item = registered_constructor(**kwargs) elif callable(registered_constructor): constructed_item = functools.partial(registered_constructor, **kwargs) else: raise ValueError( - f'Expected {name} to be a class or function, but got {type(registered_constructor)}' + f'Expected {name} to be a class or function, but got {type(registered_constructor)}', ) if post_validation_function is not None: diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py index fb0046f938..83b2d1a32a 100644 --- a/llmfoundry/utils/warnings.py +++ b/llmfoundry/utils/warnings.py @@ -33,8 +33,9 @@ class VersionedDeprecationWarning(UserWarning): """ def __init__(self, message: str, remove_version: str) -> None: - super().__init__(message + - f' It will be removed in version {remove_version}.') + super().__init__( + message + f' It will be removed in version {remove_version}.', + ) class ExperimentalWarning(Warning): @@ -46,7 +47,7 @@ class ExperimentalWarning(Warning): def __init__(self, feature_name: str) -> None: super().__init__( - f'{feature_name} is experimental and may change with future versions.' + f'{feature_name} is experimental and may change with future versions.', ) diff --git a/pyproject.toml b/pyproject.toml index 6bde062abb..53007cafaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,25 @@ build-backend = "setuptools.build_meta" multi_line_output = 0 line_length = 80 skip = [ "env", "wandb", "runs", "build", "node_modules" ] +include_trailing_comma = true +split_on_trailing_comma = true + +[tool.ruff.lint] +select = [ + "C4", + # TODO port pydocstyle + # "D", # pydocstyle + "LOG", + "PERF", + "PLE", + "COM812", +] +[tool.ruff] +exclude = [ + "build/**", + "docs/**", + "node_modules/**", +] # Coverage [tool.coverage.run] @@ -163,7 +182,7 @@ blank_line_before_nested_class_or_def = true # 'key1': 'value1', # 'key2': 'value2', # }) -coalesce_brackets = false +coalesce_brackets = true # The column limit. column_limit = 80 @@ -198,7 +217,7 @@ continuation_indent_width = 4 # start_ts=now()-timedelta(days=3), # end_ts=now(), # ) # <--- this bracket is dedented and on a separate line -dedent_closing_brackets = false +dedent_closing_brackets = true # Disable the heuristic which places each list element on a separate line # if the list is comma-terminated. @@ -369,7 +388,7 @@ split_all_top_level_comma_separated_values = false # Split before arguments if the argument list is terminated by a # comma. -split_arguments_when_comma_terminated = false +split_arguments_when_comma_terminated = true # Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@' # rather than after. diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index 856d299b64..d7aaa52193 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -30,17 +30,20 @@ def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( description= - 'Convert dataset into MDS format, optionally concatenating and tokenizing' + 'Convert dataset into MDS format, optionally concatenating and tokenizing', ) parser.add_argument('--dataset', type=str, required=True) - parser.add_argument('--data_subset', - type=str, - default=None, - help='E.g. "all" or "en"') + parser.add_argument( + '--data_subset', + type=str, + default=None, + help='E.g. "all" or "en"', + ) parser.add_argument( '--splits', nargs='+', - default=['train', 'train_small', 'val', 'val_small', 'val_xsmall']) + default=['train', 'train_small', 'val', 'val_small', 'val_xsmall'], + ) parser.add_argument('--out_root', type=str, required=True) parser.add_argument('--compression', type=str, default=None) @@ -48,7 +51,8 @@ def parse_args() -> Namespace: group.add_argument( '--concat_tokens', type=int, - help='Convert text to tokens and concatenate up to this many tokens') + help='Convert text to tokens and concatenate up to this many tokens', + ) parser.add_argument('--tokenizer', type=str, required=False, default=None) parser.add_argument('--tokenizer_kwargs', type=str, required=False) @@ -65,17 +69,20 @@ def parse_args() -> Namespace: parsed.tokenizer_kwargs = {} if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set( - parsed.splits))) > 0: + set(os.listdir(parsed.out_root)).intersection(set(parsed.splits)), + ) > 0: raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.' + f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', ) # Make sure we have needed concat options - if (parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None): + if ( + parsed.concat_tokens is not None and + isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None + ): parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer') + 'When setting --concat_tokens, you must specify a --tokenizer', + ) # now that we have validated them, change BOS/EOS to strings if parsed.bos_text is None: @@ -100,97 +107,121 @@ class DatasetConstants: splits: Dict[str, DataSplitConstants] = field(default_factory=dict) def __iter__(self): - for _, v in self.splits.items(): + for v in self.splits.values(): yield v class TrainSmallConstants(DataSplitConstants): - def __init__(self, - hf_split: str = 'train', - folder_split: str = 'train_small', - raw_samples: int = 100000, - truncated_samples: int = 100000): + def __init__( + self, + hf_split: str = 'train', + folder_split: str = 'train_small', + raw_samples: int = 100000, + truncated_samples: int = 100000, + ): super().__init__(hf_split, folder_split, raw_samples, truncated_samples) class ValSmallConstants(DataSplitConstants): - def __init__(self, - hf_split: str = 'validation', - folder_split: str = 'val_small', - raw_samples: int = 10000, - truncated_samples: int = 10000): + def __init__( + self, + hf_split: str = 'validation', + folder_split: str = 'val_small', + raw_samples: int = 10000, + truncated_samples: int = 10000, + ): super().__init__(hf_split, folder_split, raw_samples, truncated_samples) class ValXSmallConstants(DataSplitConstants): - def __init__(self, - hf_split: str = 'validation', - folder_split: str = 'val_xsmall', - raw_samples: int = 3000, - truncated_samples: int = 3000): + def __init__( + self, + hf_split: str = 'validation', + folder_split: str = 'val_xsmall', + raw_samples: int = 3000, + truncated_samples: int = 3000, + ): super().__init__(hf_split, folder_split, raw_samples, truncated_samples) pileconstants = DatasetConstants( chars_per_sample=6212, # Computed over validation set - chars_per_token=4 # OpenAI estimate + chars_per_token=4, # OpenAI estimate +) +pileconstants.splits['train'] = DataSplitConstants( + hf_split='train', + folder_split='train', + raw_samples=210607728, + truncated_samples=None, ) -pileconstants.splits['train'] = DataSplitConstants(hf_split='train', - folder_split='train', - raw_samples=210607728, - truncated_samples=None) pileconstants.splits['train_small'] = DataSplitConstants( hf_split='train', folder_split='train_small', raw_samples=100000, - truncated_samples=100000) -pileconstants.splits['val'] = DataSplitConstants(hf_split='validation', - folder_split='val', - raw_samples=214670, - truncated_samples=None) -pileconstants.splits['val_small'] = DataSplitConstants(hf_split='validation', - folder_split='val_small', - raw_samples=10000, - truncated_samples=10000) + truncated_samples=100000, +) +pileconstants.splits['val'] = DataSplitConstants( + hf_split='validation', + folder_split='val', + raw_samples=214670, + truncated_samples=None, +) +pileconstants.splits['val_small'] = DataSplitConstants( + hf_split='validation', + folder_split='val_small', + raw_samples=10000, + truncated_samples=10000, +) pileconstants.splits['val_xsmall'] = DataSplitConstants( hf_split='validation', folder_split='val_xsmall', raw_samples=3000, - truncated_samples=3000) + truncated_samples=3000, +) c4constants = DatasetConstants( chars_per_sample=2163, # Computed over validation set - chars_per_token=4 # OpenAI estimate + chars_per_token=4, # OpenAI estimate +) +c4constants.splits['train'] = DataSplitConstants( + hf_split='train', + folder_split='train', + raw_samples=364868892, + truncated_samples=None, ) -c4constants.splits['train'] = DataSplitConstants(hf_split='train', - folder_split='train', - raw_samples=364868892, - truncated_samples=None) c4constants.splits['train_small'] = DataSplitConstants( hf_split='train', folder_split='train_small', raw_samples=100000, - truncated_samples=100000) -c4constants.splits['val'] = DataSplitConstants(hf_split='validation', - folder_split='val', - raw_samples=364608, - truncated_samples=None) -c4constants.splits['val_small'] = DataSplitConstants(hf_split='validation', - folder_split='val_small', - raw_samples=10000, - truncated_samples=10000) -c4constants.splits['val_xsmall'] = DataSplitConstants(hf_split='validation', - folder_split='val_xsmall', - raw_samples=3000, - truncated_samples=3000) + truncated_samples=100000, +) +c4constants.splits['val'] = DataSplitConstants( + hf_split='validation', + folder_split='val', + raw_samples=364608, + truncated_samples=None, +) +c4constants.splits['val_small'] = DataSplitConstants( + hf_split='validation', + folder_split='val_small', + raw_samples=10000, + truncated_samples=10000, +) +c4constants.splits['val_xsmall'] = DataSplitConstants( + hf_split='validation', + folder_split='val_xsmall', + raw_samples=3000, + truncated_samples=3000, +) c4constants.splits['val_xxsmall'] = DataSplitConstants( hf_split='validation', folder_split='val_xxsmall', raw_samples=100, - truncated_samples=100) + truncated_samples=100, +) CONSTS = {'c4': c4constants, 'the_pile': pileconstants} @@ -223,41 +254,50 @@ def build_hf_dataset( Returns: An IterableDataset. """ - hf_dataset = hf_datasets.load_dataset(path=dataset_name, - name=data_subset, - split=split, - streaming=True) + hf_dataset = hf_datasets.load_dataset( + path=dataset_name, + name=data_subset, + split=split, + streaming=True, + ) if mode == ConcatMode.NO_CONCAT: dataset = NoConcatDataset(hf_dataset) else: if not isinstance(tokenizer, PreTrainedTokenizerBase): raise ValueError( - f'{tokenizer=} must be of type PreTrainedTokenizerBase') + f'{tokenizer=} must be of type PreTrainedTokenizerBase', + ) if max_length is None: raise ValueError(f'max_length must be set.') if bos_text + eos_text == '': test_tokens = tokenizer('test') if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ - -1] != tokenizer.eos_token_id: + 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + -1] != tokenizer.eos_token_id: tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' tok_error_msg += '--bos_text=<|endoftext|>.' raise ValueError(tok_error_msg) - dataset = ConcatTokensDataset(hf_dataset=hf_dataset, - tokenizer=tokenizer, - max_length=max_length, - bos_text=bos_text, - eos_text=eos_text, - no_wrap=no_wrap) + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + tokenizer=tokenizer, + max_length=max_length, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + ) return dataset -def _est_progress_denominator(total_samples: int, chars_per_sample: int, - chars_per_token: int, mode: ConcatMode, - max_length: int): +def _est_progress_denominator( + total_samples: int, + chars_per_sample: int, + chars_per_token: int, + mode: ConcatMode, + max_length: int, +): est_tokens_per_sample = chars_per_sample // chars_per_token if mode == ConcatMode.NO_CONCAT: return total_samples @@ -265,8 +305,11 @@ def _est_progress_denominator(total_samples: int, chars_per_sample: int, return total_samples * est_tokens_per_sample // max_length -def build_dataloader(dataset: Dataset, batch_size: int, - num_workers: Optional[int]) -> DataLoader: +def build_dataloader( + dataset: Dataset, + batch_size: int, + num_workers: Optional[int], +) -> DataLoader: if num_workers is None: # Multiple workers is only supported on linux machines if 'linux' or 'macos' in platform.platform().lower(): @@ -278,8 +321,10 @@ def build_dataloader(dataset: Dataset, batch_size: int, # the aggregate device batch size # If not using workers, the torch DataLoader expects the default value for prefetch_factor, # which non-intuitively must be 2. - prefetch_factor = max(1, 2 * batch_size // - num_workers) if num_workers > 0 else 2 + prefetch_factor = max( + 1, + 2 * batch_size // num_workers, + ) if num_workers > 0 else 2 return DataLoader( dataset=dataset, @@ -291,8 +336,8 @@ def build_dataloader(dataset: Dataset, batch_size: int, def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None + loader: DataLoader, + truncate_num_samples: Optional[int] = None, ) -> Iterable[Dict[str, bytes]]: """Generator over samples of a dataloader. @@ -324,7 +369,7 @@ def main(args: Namespace) -> None: dataset_constants = CONSTS[args.dataset] except KeyError: raise ValueError( - f'Constants for dataset "{args.dataset}" not found. Currently only "the_pile" and "c4" are supported.' + f'Constants for dataset "{args.dataset}" not found. Currently only "the_pile" and "c4" are supported.', ) if args.concat_tokens is not None: @@ -352,20 +397,26 @@ def main(args: Namespace) -> None: continue # Get samples - dataset = build_hf_dataset(dataset_name=args.dataset, - data_subset=args.data_subset, - split=hf_split, - mode=mode, - max_length=args.concat_tokens, - bos_text=args.bos_text, - eos_text=args.eos_text, - no_wrap=args.no_wrap, - tokenizer=tokenizer) - loader = build_dataloader(dataset=dataset, - batch_size=512, - num_workers=args.num_workers) - samples = generate_samples(loader, - truncate_num_samples=truncate_num_samples) + dataset = build_hf_dataset( + dataset_name=args.dataset, + data_subset=args.data_subset, + split=hf_split, + mode=mode, + max_length=args.concat_tokens, + bos_text=args.bos_text, + eos_text=args.eos_text, + no_wrap=args.no_wrap, + tokenizer=tokenizer, + ) + loader = build_dataloader( + dataset=dataset, + batch_size=512, + num_workers=args.num_workers, + ) + samples = generate_samples( + loader, + truncate_num_samples=truncate_num_samples, + ) if expected_num_samples is not None: denominator = truncate_num_samples if truncate_num_samples is not None else _est_progress_denominator( @@ -381,15 +432,19 @@ def main(args: Namespace) -> None: # Write samples print(f'Converting {folder_split} to MDS format...') print( - f'Note: the progress bar is based on the dataset length before tokenization, and may finish at a value before 100%.' + f'Note: the progress bar is based on the dataset length before tokenization, and may finish at a value before 100%.', ) - with MDSWriter(columns=columns, - out=os.path.join(args.out_root, folder_split), - compression=args.compression) as out: + with MDSWriter( + columns=columns, + out=os.path.join(args.out_root, folder_split), + compression=args.compression, + ) as out: if denominator is not None: - for sample in tqdm(samples, - desc=folder_split, - total=denominator): + for sample in tqdm( + samples, + desc=folder_split, + total=denominator, + ): out.write(sample) else: for sample in tqdm(samples, desc=folder_split): diff --git a/scripts/data_prep/convert_dataset_json.py b/scripts/data_prep/convert_dataset_json.py index 54c0bfa814..fb117ddef3 100644 --- a/scripts/data_prep/convert_dataset_json.py +++ b/scripts/data_prep/convert_dataset_json.py @@ -26,7 +26,7 @@ def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( description= - 'Convert dataset into MDS format, optionally concatenating and tokenizing' + 'Convert dataset into MDS format, optionally concatenating and tokenizing', ) parser.add_argument('--path', type=str, required=True) parser.add_argument('--out_root', type=str, required=True) @@ -36,7 +36,8 @@ def parse_args() -> Namespace: group.add_argument( '--concat_tokens', type=int, - help='Convert text to tokens and concatenate up to this many tokens') + help='Convert text to tokens and concatenate up to this many tokens', + ) parser.add_argument('--split', type=str, default='train') parser.add_argument('--tokenizer', type=str, required=False, default=None) @@ -47,17 +48,20 @@ def parse_args() -> Namespace: parsed = parser.parse_args() if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set( - parsed.split))) > 0: + set(os.listdir(parsed.out_root)).intersection(set(parsed.split)), + ) > 0: raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.' + f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', ) # Make sure we have needed concat options - if (parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None): + if ( + parsed.concat_tokens is not None and + isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None + ): parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer') + 'When setting --concat_tokens, you must specify a --tokenizer', + ) # now that we have validated them, change BOS/EOS to strings if parsed.bos_text is None: @@ -99,41 +103,46 @@ def build_hf_dataset( else: data_files = path - hf_dataset = hf_datasets.load_dataset('json', - data_files=data_files, - split=split) + hf_dataset = hf_datasets.load_dataset( + 'json', + data_files=data_files, + split=split, + ) if mode == ConcatMode.NO_CONCAT: dataset = NoConcatDataset(hf_dataset) else: if not isinstance(tokenizer, PreTrainedTokenizerBase): raise ValueError( - f'{tokenizer=} must be of type PreTrainedTokenizerBase') + f'{tokenizer=} must be of type PreTrainedTokenizerBase', + ) if max_length is None: raise ValueError(f'max_length must be set.') if bos_text + eos_text == '': test_tokens = tokenizer('test') if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ - -1] != tokenizer.eos_token_id: + 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + -1] != tokenizer.eos_token_id: tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' tok_error_msg += '--bos_text=<|endoftext|>.' raise ValueError(tok_error_msg) - dataset = ConcatTokensDataset(hf_dataset=hf_dataset, - tokenizer=tokenizer, - max_length=max_length, - bos_text=bos_text, - eos_text=eos_text, - no_wrap=no_wrap) + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + tokenizer=tokenizer, + max_length=max_length, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + ) return dataset def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None + loader: DataLoader, + truncate_num_samples: Optional[int] = None, ) -> Iterable[Dict[str, bytes]]: """Generator over samples of a dataloader. @@ -173,26 +182,30 @@ def main(args: Namespace) -> None: columns = {'text': 'str'} # Get samples - dataset = build_hf_dataset(path=args.path, - split=args.split, - mode=mode, - max_length=args.concat_tokens, - bos_text=args.bos_text, - eos_text=args.eos_text, - no_wrap=args.no_wrap, - tokenizer=tokenizer) + dataset = build_hf_dataset( + path=args.path, + split=args.split, + mode=mode, + max_length=args.concat_tokens, + bos_text=args.bos_text, + eos_text=args.eos_text, + no_wrap=args.no_wrap, + tokenizer=tokenizer, + ) print('here') # Write samples print(f'Converting to MDS format...') print( - f'Note that the progress bar is based on the dataset length before tokenization.' + f'Note that the progress bar is based on the dataset length before tokenization.', ) print(f'It will finish at a value below 100% if tokenizing') - with MDSWriter(columns=columns, - out=os.path.join(args.out_root), - compression=args.compression) as out: + with MDSWriter( + columns=columns, + out=os.path.join(args.out_root), + compression=args.compression, + ) as out: for sample in tqdm(dataset): out.write(sample) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index aefafdb49a..d871761803 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -34,9 +34,11 @@ from pyspark.sql.types import Row from llmfoundry.utils import maybe_create_mosaicml_logger -from llmfoundry.utils.exceptions import (ClusterDoesNotExistError, - FailedToConnectToDatabricksError, - FailedToCreateSQLConnectionError) +from llmfoundry.utils.exceptions import ( + ClusterDoesNotExistError, + FailedToConnectToDatabricksError, + FailedToCreateSQLConnectionError, +) MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' @@ -44,8 +46,14 @@ log = logging.getLogger(__name__) Result = namedtuple( - 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' - ]) # pyright: ignore + 'Result', + [ + 'url', + 'row_count', + 'compressed_size', + 'uncompressed_size', + ], +) # pyright: ignore # ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. # It allows the client to fetch the results in different formats from the server. @@ -84,7 +92,7 @@ def to_cf(self: SparkConnectClient, format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW else: raise ValueError( - f'Only formats json, csv, and arrow are supported. Got invalid type {type}' + f'Only formats json, csv, and arrow are supported. Got invalid type {type}', ) ro = cloud_pb2.ResultOptions( @@ -92,16 +100,21 @@ def to_cf(self: SparkConnectClient, cloudOptions=cloud_pb2.ResultOptions.CloudOptions( format=format, useCompression=False, - )) + ), + ) cloud_option = any_pb2.Any() cloud_option.Pack(ro) req.request_options.append( - pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option), + ) # Create the iterator - iterator = ExecutePlanResponseReattachableIterator(req, self._stub, - self._retry_policy, - self._builder.metadata()) + iterator = ExecutePlanResponseReattachableIterator( + req, + self._stub, + self._retry_policy, + self._builder.metadata(), + ) # Iterate over the response result = [] row_count = 0 @@ -109,15 +122,21 @@ def to_cf(self: SparkConnectClient, for response in iterator: if response.HasField('extension') and response.extension.Is( - cloud_pb2.CloudResultBatch.DESCRIPTOR): + cloud_pb2.CloudResultBatch.DESCRIPTOR, + ): batch = cloud_pb2.CloudResultBatch() if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): raise ValueError( - 'Response extension is not of type CloudResultBatch.') + 'Response extension is not of type CloudResultBatch.', + ) response.extension.Unpack(batch) result += [ - Result(b.url, b.row_count, b.compressed_size, - b.uncompressed_size) for b in batch.results + Result( + b.url, + b.row_count, + b.compressed_size, + b.uncompressed_size, + ) for b in batch.results ] row_count += sum(result.row_count for result in batch.results) is_overflow |= batch.truncated @@ -175,7 +194,7 @@ def run_query( method: str, cursor: Optional[Cursor] = None, spark: Optional[SparkSession] = None, - collect: bool = True + collect: bool = True, ) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: """Run SQL query via databricks-connect or databricks-sql. @@ -208,12 +227,14 @@ def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: yield (i, r.url, json_output_folder, columns) -def download(ipart: int, - url: str, - json_output_folder: str, - columns: Optional[List] = None, - resp_format: str = 'arrow', - compressed: bool = False) -> None: +def download( + ipart: int, + url: str, + json_output_folder: str, + columns: Optional[List] = None, + resp_format: str = 'arrow', + compressed: bool = False, +) -> None: """Thread download presigned url and save to jsonl locally. Args: @@ -228,10 +249,14 @@ def download(ipart: int, if resp.status_code == 200: if resp_format == 'json': data = resp.json() - pd.DataFrame(data, columns=columns).to_json(os.path.join( - json_output_folder, 'part_' + str(ipart) + '.jsonl'), - orient='records', - lines=True) + pd.DataFrame(data, columns=columns).to_json( + os.path.join( + json_output_folder, + 'part_' + str(ipart) + '.jsonl', + ), + orient='records', + lines=True, + ) return # When resp_format is arrow: @@ -247,21 +272,29 @@ def download(ipart: int, # Convert the PyArrow table into a pandas DataFrame df = table.to_pandas() - df.to_json(os.path.join(json_output_folder, - 'part_' + str(ipart) + '.jsonl'), - orient='records', - lines=True, - force_ascii=False) + df.to_json( + os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True, + force_ascii=False, + ) def download_starargs(args: Tuple) -> None: return download(*args) -def fetch_data(method: str, cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], start: int, end: int, - order_by: str, tablename: str, columns_str: str, - json_output_folder: str) -> None: +def fetch_data( + method: str, + cursor: Optional[Cursor], + sparkSession: Optional[SparkSession], + start: int, + end: int, + order_by: str, + tablename: str, + columns_str: str, + json_output_folder: str, +) -> None: """Fetches a specified range of rows from a given table to a json file. This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, @@ -297,7 +330,8 @@ def fetch_data(method: str, cursor: Optional[Cursor], spark_df = run_query(query, method, cursor, sparkSession, collect=False) if spark_df is None: raise RuntimeError( - f'Expect spark dataframe with {query} but got None') + f'Expect spark dataframe with {query} but got None', + ) pdf = spark_df.toPandas() # pyright: ignore else: # method == 'dbsql': ans = run_query(query, method, cursor, sparkSession, collect=True) @@ -306,9 +340,11 @@ def fetch_data(method: str, cursor: Optional[Cursor], records = [r.asDict() for r in ans] # pyright: ignore pdf = pd.DataFrame.from_dict(records) - pdf.to_json(os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), - orient='records', - lines=True) + pdf.to_json( + os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), + orient='records', + lines=True, + ) def fetch( @@ -334,29 +370,37 @@ def fetch( cursor = dbsql.cursor() if dbsql is not None else None try: - ans = run_query(f'SELECT COUNT(*) FROM {tablename}', method, cursor, - sparkSession) + ans = run_query( + f'SELECT COUNT(*) FROM {tablename}', + method, + cursor, + sparkSession, + ) nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore log.info(f'total_rows = {nrows}') except Exception as e: raise RuntimeError( - f'Error in get total rows from {tablename}. Restart sparkSession and try again' + f'Error in get total rows from {tablename}. Restart sparkSession and try again', ) from e try: - ans = run_query(f'SHOW COLUMNS IN {tablename}', method, cursor, - sparkSession) + ans = run_query( + f'SHOW COLUMNS IN {tablename}', + method, + cursor, + sparkSession, + ) columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore order_by = columns[0] columns_str = ','.join(columns) log.info(f'order by column {order_by}') except Exception as e: raise RuntimeError( - f'Error in get columns from {tablename}. Restart sparkSession and try again' + f'Error in get columns from {tablename}. Restart sparkSession and try again', ) from e if method == 'dbconnect' and sparkSession is not None: - log.info('processes = ', processes) + log.info(f'{processes=}') df = sparkSession.table(tablename) # Running the query and collecting the data as arrow or json. @@ -375,18 +419,29 @@ def fetch( for start in range(0, nrows, batch_size): log.warning(f'batch {start}') end = min(start + batch_size, nrows) - fetch_data(method, cursor, sparkSession, start, end, order_by, - tablename, columns_str, json_output_folder) + fetch_data( + method, + cursor, + sparkSession, + start, + end, + order_by, + tablename, + columns_str, + json_output_folder, + ) if cursor is not None: cursor.close() -def validate_and_get_cluster_info(cluster_id: str, - databricks_host: str, - databricks_token: str, - http_path: Optional[str], - use_serverless: bool = False) -> tuple: +def validate_and_get_cluster_info( + cluster_id: str, + databricks_host: str, + databricks_token: str, + http_path: Optional[str], + use_serverless: bool = False, +) -> tuple: """Validate and get cluster info for running the Delta to JSONL conversion. Args: @@ -411,18 +466,22 @@ def validate_and_get_cluster_info(cluster_id: str, stripped_runtime = re.sub( r'[a-zA-Z]', '', - res.spark_version.split('-scala')[0].replace( # type: ignore - 'x-snapshot', '')) + res.spark_version.split('-scala') + [0].replace( # type: ignore + 'x-snapshot', '', + ), + ) runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) - if version.parse(runtime_version) < version.parse( - MINIMUM_SQ_CONNECT_DBR_VERSION): + if version.parse( + runtime_version, + ) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION): raise ValueError( - f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}' + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}', ) if http_path is None and version.parse( - runtime_version) >= version.parse( - MINIMUM_DB_CONNECT_DBR_VERSION): + runtime_version, + ) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): method = 'dbconnect' if method == 'dbconnect': @@ -430,14 +489,17 @@ def validate_and_get_cluster_info(cluster_id: str, if use_serverless: session_id = str(uuid4()) sparkSession = DatabricksSession.builder.host( - databricks_host).token(databricks_token).header( - 'x-databricks-session-id', session_id).getOrCreate() + databricks_host, + ).token( + databricks_token, + ).header('x-databricks-session-id', session_id).getOrCreate() else: sparkSession = DatabricksSession.builder.remote( host=databricks_host, token=databricks_token, - cluster_id=cluster_id).getOrCreate() + cluster_id=cluster_id, + ).getOrCreate() except Exception as e: raise FailedToConnectToDatabricksError() from e @@ -462,13 +524,15 @@ def fetch_DT(args: Namespace) -> None: obj = urllib.parse.urlparse(args.json_output_folder) if obj.scheme != '': raise ValueError( - 'Check the json_output_folder and verify it is a local path!') + 'Check the json_output_folder and verify it is a local path!', + ) if os.path.exists(args.json_output_folder): if not os.path.isdir(args.json_output_folder) or os.listdir( - args.json_output_folder): + args.json_output_folder, + ): raise RuntimeError( - f'Output folder {args.json_output_folder} already exists and is not empty. Please remove it and retry.' + f'Output folder {args.json_output_folder} already exists and is not empty. Please remove it and retry.', ) os.makedirs(args.json_output_folder, exist_ok=True) @@ -483,10 +547,18 @@ def fetch_DT(args: Namespace) -> None: databricks_host=args.DATABRICKS_HOST, databricks_token=args.DATABRICKS_TOKEN, http_path=args.http_path, - use_serverless=args.use_serverless) + use_serverless=args.use_serverless, + ) - fetch(method, args.delta_table_name, args.json_output_folder, - args.batch_size, args.processes, sparkSession, dbsql) + fetch( + method, + args.delta_table_name, + args.json_output_folder, + args.batch_size, + args.processes, + sparkSession, + dbsql, + ) if dbsql is not None: dbsql.close() @@ -494,41 +566,53 @@ def fetch_DT(args: Namespace) -> None: # combine downloaded jsonl into one big jsonl for IFT iterative_combine_jsons( args.json_output_folder, - os.path.join(args.json_output_folder, args.json_output_filename)) + os.path.join(args.json_output_folder, args.json_output_filename), + ) if __name__ == '__main__': parser = ArgumentParser( description= - 'Download delta table from UC and convert to json to save local') - parser.add_argument('--delta_table_name', - required=True, - type=str, - help='UC table ..') - parser.add_argument('--json_output_folder', - required=True, - type=str, - help='Local path to save the converted json') - parser.add_argument('--http_path', - required=False, - type=str, - help='http_path is set then dbsql method is used') - parser.add_argument('--batch_size', - required=False, - type=int, - default=1 << 30, - help='row chunks to transmit a time to avoid OOM') - parser.add_argument('--processes', - required=False, - type=int, - default=os.cpu_count(), - help='number of processes allowed to use') + 'Download delta table from UC and convert to json to save local', + ) + parser.add_argument( + '--delta_table_name', + required=True, + type=str, + help='UC table ..
', + ) + parser.add_argument( + '--json_output_folder', + required=True, + type=str, + help='Local path to save the converted json', + ) + parser.add_argument( + '--http_path', + required=False, + type=str, + help='http_path is set then dbsql method is used', + ) + parser.add_argument( + '--batch_size', + required=False, + type=int, + default=1 << 30, + help='row chunks to transmit a time to avoid OOM', + ) + parser.add_argument( + '--processes', + required=False, + type=int, + default=os.cpu_count(), + help='number of processes allowed to use', + ) parser.add_argument( '--cluster_id', required=False, type=str, help= - 'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.' + 'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.', ) parser.add_argument( '--use_serverless', @@ -536,7 +620,7 @@ def fetch_DT(args: Namespace) -> None: type=bool, default=False, help= - 'Use serverless or not. Make sure the workspace is entitled with serverless' + 'Use serverless or not. Make sure the workspace is entitled with serverless', ) parser.add_argument( '--json_output_filename', @@ -544,7 +628,7 @@ def fetch_DT(args: Namespace) -> None: type=str, default='train-00000-of-00001.jsonl', help= - 'The name of the combined final jsonl that combines all partitioned jsonl' + 'The name of the combined final jsonl that combines all partitioned jsonl', ) args = parser.parse_args() mosaicml_logger = maybe_create_mosaicml_logger() @@ -556,7 +640,7 @@ def fetch_DT(args: Namespace) -> None: tik = time.time() fetch_DT(args) - log.info('Elapsed time', time.time() - tik) + log.info(f'Elapsed time {time.time() - tik}') except Exception as e: if mosaicml_logger is not None: diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index fb6bde4115..523d45093d 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -16,10 +16,12 @@ from tqdm import tqdm from llmfoundry.data.finetuning.collator import validate_target_settings -from llmfoundry.data.finetuning.tasks import (_get_example_type, - dataset_constructor, - is_valid_ift_example, - tokenize_formatted_example) +from llmfoundry.data.finetuning.tasks import ( + _get_example_type, + dataset_constructor, + is_valid_ift_example, + tokenize_formatted_example, +) from llmfoundry.utils.builders import build_tokenizer HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] @@ -33,16 +35,20 @@ def parse_args() -> Namespace: type=str, required=True, help= - 'Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`)' + 'Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`)', + ) + parser.add_argument( + '--data_subset', + type=str, + default=None, + help='(Optional) subset of data to use.', + ) + parser.add_argument( + '--splits', + nargs='+', + default=['train', 'validation'], + help='Which splits of the dataset to convert.', ) - parser.add_argument('--data_subset', - type=str, - default=None, - help='(Optional) subset of data to use.') - parser.add_argument('--splits', - nargs='+', - default=['train', 'validation'], - help='Which splits of the dataset to convert.') parser.add_argument('--preprocessor', type=str, default=None, @@ -53,32 +59,34 @@ def parse_args() -> Namespace: nargs='+', default=[], help= - 'Data file for each split. If set, its length should be exact same as len(splits)' + 'Data file for each split. If set, its length should be exact same as len(splits)', ) parser.add_argument( '--skip-preprocessing', action='store_true', help= - 'Whether to skip preprocessing (e.g., if the dataset is already formatted correctly)' + 'Whether to skip preprocessing (e.g., if the dataset is already formatted correctly)', ) parser.add_argument( '--out_root', type=str, required=True, help= - 'Root path of output directory where MDS shards will be stored. Can be a remote URI.' + 'Root path of output directory where MDS shards will be stored. Can be a remote URI.', ) parser.add_argument( '--local', type=str, default=None, help= - '(Optional) root path of local directory if you want to keep a local copy when out_root is remote.' + '(Optional) root path of local directory if you want to keep a local copy when out_root is remote.', + ) + parser.add_argument( + '--compression', + type=str, + default=None, + help='(Optional) name of compression algorithm to use.', ) - parser.add_argument('--compression', - type=str, - default=None, - help='(Optional) name of compression algorithm to use.') parser.add_argument('--num_workers', type=int, required=False, default=None) parser.add_argument('--tokenizer', type=str, required=False, default=None) parser.add_argument('--tokenizer_kwargs', type=str, required=False) @@ -90,30 +98,30 @@ def parse_args() -> Namespace: type=str, default='none', help='Used to determine which samples are valid at max_seq_len. ' +\ - 'This is the policy for when to use prompts as training targets. Default "none" means prompts are never used as training targets.' + 'This is the policy for when to use prompts as training targets. Default "none" means prompts are never used as training targets.', ) parser.add_argument( '--target_responses', type=str, default='last', help='Used to determine which samples are valid at max_seq_len. ' +\ - 'This is the policy for which responses to treat as training targets. Default "last" means the only the final response (if multi-turn) is used.' + 'This is the policy for which responses to treat as training targets. Default "last" means the only the final response (if multi-turn) is used.', ) parser.add_argument( '--encoder_decoder', action='store_true', help='Used to determine which samples are valid at max_seq_len. ' +\ 'Set this flag if the data are intended to be used to train an encoder-decoder model. If so, you must use the default ' +\ - '``target_prompts`` and ``target_responses`` settings of "none" and "last", respectively.' + '``target_prompts`` and ``target_responses`` settings of "none" and "last", respectively.', ) parsed = parser.parse_args() if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set( - parsed.splits))) > 0: + set(os.listdir(parsed.out_root)).intersection(set(parsed.splits)), + ) > 0: raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.' + f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', ) if parsed.tokenizer_kwargs is not None: @@ -121,18 +129,21 @@ def parse_args() -> Namespace: else: parsed.tokenizer_kwargs = {} - if len(parsed.data_files) > 0 and len(parsed.data_files) != len( - parsed.splits): + if len(parsed.data_files) > 0 and len( + parsed.data_files, + ) != len(parsed.splits): raise ValueError( - f'If data_files is set, data_files and splits must have the same length. Got {len(parsed.data_files)=} while {len(parsed.splits)=}' + f'If data_files is set, data_files and splits must have the same length. Got {len(parsed.data_files)=} while {len(parsed.splits)=}', ) return parsed -def build_dataloader(dataset: HFDataset, - batch_size: int, - num_workers: Optional[int] = None) -> DataLoader: +def build_dataloader( + dataset: HFDataset, + batch_size: int, + num_workers: Optional[int] = None, +) -> DataLoader: if num_workers is None: # Multiple workers is only supported on linux machines if 'linux' in platform.platform().lower(): @@ -148,8 +159,10 @@ def build_dataloader(dataset: HFDataset, if 'macos' in platform.platform().lower() and num_workers == 0: prefetch_factor = None else: - prefetch_factor = max(1, 2 * batch_size // - num_workers) if num_workers > 0 else 2 + prefetch_factor = max( + 1, + 2 * batch_size // num_workers, + ) if num_workers > 0 else 2 return DataLoader( dataset=dataset, @@ -161,8 +174,8 @@ def build_dataloader(dataset: HFDataset, def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None + loader: DataLoader, + truncate_num_samples: Optional[int] = None, ) -> Iterable[Dict[str, bytes]]: """Generator over samples of a dataloader. @@ -184,8 +197,11 @@ def generate_samples( yield {k: v[idx] for k, v in batch.items()} -def get_columns_and_format(dataset: HFDataset, tokenizing: bool, - preprocessing_fn: Callable): +def get_columns_and_format( + dataset: HFDataset, + tokenizing: bool, + preprocessing_fn: Callable, +): ex = preprocessing_fn(next(iter(dataset))) example_type = _get_example_type(ex) if tokenizing: @@ -209,13 +225,15 @@ def main(args: Namespace) -> None: else: preprocessor_str = args.preprocessor preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( - preprocessor=preprocessor_str, dataset_name=args.dataset) + preprocessor=preprocessor_str, + dataset_name=args.dataset, + ) if preprocessing_fn is None: raise ValueError( '`args.preprocessor` was not set and no preprocessing function ' +\ 'has been registered for `args.dataset`. If this was intentional ' +\ '(e.g., because your dataset is already correctly formatted), ' +\ - 'include the "--skip-preprocessing" flag to avoid this error.' + 'include the "--skip-preprocessing" flag to avoid this error.', ) # Make sure the target settings are valid @@ -235,23 +253,28 @@ def main(args: Namespace) -> None: data_file = None if len(args.data_files) > 0: data_file = args.data_files[i] - dataset = hf_datasets.load_dataset(path=args.dataset, - name=args.data_subset, - split=split_name, - data_files=data_file, - streaming=True) + dataset = hf_datasets.load_dataset( + path=args.dataset, + name=args.data_subset, + split=split_name, + data_files=data_file, + streaming=True, + ) # Determine the output columns columns, example_type = get_columns_and_format( dataset=dataset, tokenizing=tokenizer is not None, - preprocessing_fn=preprocessing_fn) + preprocessing_fn=preprocessing_fn, + ) # Prepare the iterables if example_type == 'chat': samples = iter(dataset) else: - loader = build_dataloader(dataset=dataset, - batch_size=512, - num_workers=args.num_workers) + loader = build_dataloader( + dataset=dataset, + batch_size=512, + num_workers=args.num_workers, + ) samples = generate_samples(loader) # Write samples @@ -262,10 +285,12 @@ def main(args: Namespace) -> None: keep_local = True else: keep_local = False - with MDSWriter(columns=columns, - out=out, - compression=args.compression, - keep_local=keep_local) as out: + with MDSWriter( + columns=columns, + out=out, + compression=args.compression, + keep_local=keep_local, + ) as out: examples_removed = 0 for sample in tqdm(samples, desc=split_name): formatted_sample = preprocessing_fn(sample) @@ -278,17 +303,20 @@ def main(args: Namespace) -> None: except Exception as e: raise ValueError( 'Encountered an error when checking example for proper formatting. ' +\ - f'example={formatted_sample}' + f'example={formatted_sample}', ) from e if tokenizer is not None: - sample = tokenize_formatted_example(formatted_sample, - tokenizer=tokenizer) + sample = tokenize_formatted_example( + formatted_sample, + tokenizer=tokenizer, + ) if not is_valid_ift_example( - args.max_seq_len, - target_prompts=args.target_prompts, - target_responses=args.target_responses, - decoder_only_format=not args.encoder_decoder, - example=sample): + args.max_seq_len, + target_prompts=args.target_prompts, + target_responses=args.target_responses, + decoder_only_format=not args.encoder_decoder, + example=sample, + ): examples_removed += 1 continue @@ -314,7 +342,7 @@ def main(args: Namespace) -> None: warnings.warn( f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, ' + - 'the prompt or response was empty, or the response was all padding tokens.' + 'the prompt or response was empty, or the response was all padding tokens.', ) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 636e85abed..1e36a681f9 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -11,18 +11,25 @@ from typing import Iterable, List, Tuple, cast import psutil -from composer.utils import (ObjectStore, maybe_create_object_store_from_uri, - parse_uri) +from composer.utils import ( + ObjectStore, + maybe_create_object_store_from_uri, + parse_uri, +) from streaming import MDSWriter from tqdm import tqdm from transformers import AutoTokenizer from llmfoundry.data import ConcatTokensDataset from llmfoundry.utils import maybe_create_mosaicml_logger -from llmfoundry.utils.data_prep_utils import (DownloadingIterable, - merge_shard_groups) -from llmfoundry.utils.exceptions import (InputFolderMissingDataError, - OutputFolderNotEmptyError) +from llmfoundry.utils.data_prep_utils import ( + DownloadingIterable, + merge_shard_groups, +) +from llmfoundry.utils.exceptions import ( + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) log = logging.getLogger(__name__) @@ -129,10 +136,12 @@ def parse_args() -> Namespace: # Ensure that eos text is not specified twice. if parsed.eos_text is not None: parser.error( - 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.' + 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.', ) tokenizer = AutoTokenizer.from_pretrained( - parsed.tokenizer, trust_remote_code=parsed.trust_remote_code) + parsed.tokenizer, + trust_remote_code=parsed.trust_remote_code, + ) parsed.eos_text = tokenizer.eos_token # now that we have validated them, change BOS/EOS to strings @@ -254,11 +263,15 @@ def download_and_convert( # Download file_names with tempfile.TemporaryDirectory() as tmp_dir: - downloading_iter = DownloadingIterable(object_names=file_names, - output_folder=tmp_dir, - object_store=object_store) + downloading_iter = DownloadingIterable( + object_names=file_names, + output_folder=tmp_dir, + object_store=object_store, + ) tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, trust_remote_code=trust_remote_code) + tokenizer_name, + trust_remote_code=trust_remote_code, + ) tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up @@ -275,9 +288,11 @@ def download_and_convert( columns = {'tokens': 'bytes'} log.info('Converting to MDS format...') - with MDSWriter(out=output_folder, - columns=columns, - compression=compression) as out: + with MDSWriter( + out=output_folder, + columns=columns, + compression=compression, + ) as out: for sample in tqdm(dataset): out.write(sample) @@ -292,8 +307,11 @@ def is_remote_path(path: str) -> bool: return backend != '' -def is_already_processed(output_root: str, args_str: str, - object_names: List[str]) -> bool: +def is_already_processed( + output_root: str, + args_str: str, + object_names: List[str], +) -> bool: """Determines whether a group of text files has already been processed. Checks the done fie at output root to determine this. @@ -313,7 +331,8 @@ def is_already_processed(output_root: str, args_str: str, done_file = os.path.join(tmp_dir, DONE_FILENAME) output_object_store.download_object( os.path.join(output_folder_prefix, DONE_FILENAME), - done_file) + done_file, + ) with open(done_file) as df: done_file_contents = df.read().splitlines() except FileNotFoundError: @@ -392,12 +411,15 @@ def convert_text_to_mds( raise InputFolderMissingDataError(input_folder) # Check if the text files in the bucket have already been processed. - if not reprocess and is_already_processed(output_folder, args_str, - object_names): + if not reprocess and is_already_processed( + output_folder, + args_str, + object_names, + ): log.info( f'Input folder {input_folder} is already processed at {output_folder} and ' + - 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.' + 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', ) return @@ -410,18 +432,37 @@ def convert_text_to_mds( if processes > 1: # Download and convert the text files in parallel - args = get_task_args(object_names, local_output_folder, input_folder, - processes, tokenizer_name, concat_tokens, eos_text, - bos_text, no_wrap, compression, trust_remote_code) + args = get_task_args( + object_names, + local_output_folder, + input_folder, + processes, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) with ProcessPoolExecutor(max_workers=processes) as executor: list(executor.map(download_and_convert_starargs, args)) # Merge the mds shards from each of the processes into a single folder merge_shard_groups(local_output_folder) else: - download_and_convert(object_names, local_output_folder, input_folder, - tokenizer_name, concat_tokens, eos_text, bos_text, - no_wrap, compression, trust_remote_code) + download_and_convert( + object_names, + local_output_folder, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) # Write a done file with the args and object names write_done_file(local_output_folder, args_str, object_names) @@ -429,7 +470,9 @@ def convert_text_to_mds( if is_remote_output: # Upload the local output to the remote location output_object_store = cast( - ObjectStore, maybe_create_object_store_from_uri(output_folder)) + ObjectStore, + maybe_create_object_store_from_uri(output_folder), + ) _, _, output_folder_prefix = parse_uri(output_folder) files_to_upload = os.listdir(local_output_folder) @@ -437,7 +480,9 @@ def convert_text_to_mds( assert not os.path.isdir(file) remote_path = os.path.join(output_folder_prefix, file) output_object_store.upload_object( - remote_path, os.path.join(local_output_folder, file)) + remote_path, + os.path.join(local_output_folder, file), + ) def _args_str(original_args: Namespace) -> str: @@ -468,18 +513,20 @@ def _args_str(original_args: Namespace) -> str: mosaicml_logger = maybe_create_mosaicml_logger() try: - convert_text_to_mds(tokenizer_name=args.tokenizer, - output_folder=args.output_folder, - input_folder=args.input_folder, - concat_tokens=args.concat_tokens, - eos_text=args.eos_text, - bos_text=args.bos_text, - no_wrap=args.no_wrap, - compression=args.compression, - processes=args.processes, - reprocess=args.reprocess, - trust_remote_code=args.trust_remote_code, - args_str=_args_str(args)) + convert_text_to_mds( + tokenizer_name=args.tokenizer, + output_folder=args.output_folder, + input_folder=args.input_folder, + concat_tokens=args.concat_tokens, + eos_text=args.eos_text, + bos_text=args.bos_text, + no_wrap=args.no_wrap, + compression=args.compression, + processes=args.processes, + reprocess=args.reprocess, + trust_remote_code=args.trust_remote_code, + args_str=_args_str(args), + ) except Exception as e: if mosaicml_logger is not None: mosaicml_logger.log_exception(e) diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 1e1cfd1bf2..d7b104100e 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -19,16 +19,26 @@ from omegaconf import OmegaConf as om from rich.traceback import install -from llmfoundry.utils import (find_mosaicml_logger, log_eval_analytics, - maybe_create_mosaicml_logger) +from llmfoundry.utils import ( + find_mosaicml_logger, + log_eval_analytics, + maybe_create_mosaicml_logger, +) install() -from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, - build_callback, build_composer_model, - build_evaluators, build_logger, - build_tokenizer) -from llmfoundry.utils.config_utils import (log_config, pop_config, - process_init_device) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_callback, + build_composer_model, + build_evaluators, + build_logger, + build_tokenizer, +) +from llmfoundry.utils.config_utils import ( + log_config, + pop_config, + process_init_device, +) from llmfoundry.utils.registry_utils import import_file log = logging.getLogger(__name__) @@ -59,9 +69,10 @@ def evaluate_model( log.info(f'Evaluating model: {model_cfg.model_name}') # Build tokenizer and model - tokenizer_cfg: Dict[str, - Any] = om.to_container(model_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = om.to_container( + model_cfg.tokenizer, + resolve=True, + ) # type: ignore tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) @@ -96,7 +107,8 @@ def evaluate_model( if fsdp_config and model_cfg.model.get('load_in_8bit', False): raise ValueError( 'The FSDP config block is not supported when loading ' + - 'Hugging Face models in 8bit.') + 'Hugging Face models in 8bit.', + ) init_context = process_init_device(model_cfg.model, fsdp_config) @@ -110,21 +122,23 @@ def evaluate_model( # Now add the eval metrics if eval_loader_config is not None: train_metrics = composer_model.get_metrics(is_train=True) - evaluators = add_metrics_to_eval_loaders(evaluators, - list(train_metrics.keys())) + evaluators = add_metrics_to_eval_loaders( + evaluators, + list(train_metrics.keys()), + ) if eval_gauntlet_df is None and eval_gauntlet_callback is not None: eval_gauntlet_df = pd.DataFrame( - columns=['model_name'] + - [avg for avg in eval_gauntlet_callback.averages] + - [t.name for t in eval_gauntlet_callback.categories]) + columns=['model_name'] + list(eval_gauntlet_callback.averages) + + [t.name for t in eval_gauntlet_callback.categories], + ) load_path = model_cfg.get('load_path', None) if model_cfg.model.name == 'mpt_causal_lm' and load_path is None: raise ValueError( 'MPT causal LMs require a load_path to the checkpoint for model evaluation.' + - ' Please check your yaml and the model_cfg to ensure that load_path is set.' + ' Please check your yaml and the model_cfg to ensure that load_path is set.', ) assert composer_model is not None @@ -154,8 +168,10 @@ def evaluate_model( if torch.cuda.is_available(): torch.cuda.synchronize() a = time.time() - trainer.eval(eval_dataloader=evaluators, - subset_num_batches=eval_subset_num_batches) + trainer.eval( + eval_dataloader=evaluators, + subset_num_batches=eval_subset_num_batches, + ) if torch.cuda.is_available(): torch.cuda.synchronize() b = time.time() @@ -166,11 +182,13 @@ def evaluate_model( def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: # Run user provided code if specified - code_paths = pop_config(cfg, - 'code_paths', - must_exist=False, - default_value=[], - convert=True) + code_paths = pop_config( + cfg, + 'code_paths', + must_exist=False, + default_value=[], + convert=True, + ) for code_path in code_paths: import_file(code_path) @@ -180,81 +198,104 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: logged_cfg: DictConfig = copy.deepcopy(cfg) model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True) - eval_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config( - cfg, 'eval_gauntlet', must_exist=False, default_value=None) - - fsdp_dict_cfg: Optional[DictConfig] = pop_config(cfg, - 'fsdp_config', - must_exist=False, - default_value=None) + eval_gauntlet_config: Optional[ + Union[str, DictConfig] + ] = pop_config(cfg, 'eval_gauntlet', must_exist=False, default_value=None) + + fsdp_dict_cfg: Optional[DictConfig] = pop_config( + cfg, + 'fsdp_config', + must_exist=False, + default_value=None, + ) fsdp_config: Optional[Dict] = om.to_container( fsdp_dict_cfg, - resolve=True) if fsdp_dict_cfg is not None else None # type: ignore + resolve=True, + ) if fsdp_dict_cfg is not None else None # type: ignore assert isinstance(fsdp_config, Dict) or fsdp_config is None # Mandatory Evaluation Parameters - icl_tasks: Union[str, ListConfig] = pop_config(cfg, - 'icl_tasks', - must_exist=True) + icl_tasks: Union[ + str, ListConfig] = pop_config(cfg, 'icl_tasks', must_exist=True) max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True) - device_eval_batch_size: int = pop_config(cfg, - 'device_eval_batch_size', - must_exist=True) - precision: str = pop_config(cfg, - 'precision', - must_exist=False, - default_value=None) - python_log_level: Optional[str] = pop_config(cfg, - 'python_log_level', - must_exist=False, - default_value='debug') + device_eval_batch_size: int = pop_config( + cfg, + 'device_eval_batch_size', + must_exist=True, + ) + precision: str = pop_config( + cfg, + 'precision', + must_exist=False, + default_value=None, + ) + python_log_level: Optional[str] = pop_config( + cfg, + 'python_log_level', + must_exist=False, + default_value='debug', + ) # Optional Evaluation Parameters with default values - eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config( - cfg, 'eval_loader', must_exist=False, default_value=None) + eval_loader_config: Optional[ + Union[DictConfig, ListConfig] + ] = pop_config(cfg, 'eval_loader', must_exist=False, default_value=None) seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17) - dist_timeout: Union[float, int] = pop_config(cfg, - 'dist_timeout', - must_exist=False, - default_value=600.0) + dist_timeout: Union[float, int] = pop_config( + cfg, + 'dist_timeout', + must_exist=False, + default_value=600.0, + ) default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = pop_config(cfg, - 'run_name', - must_exist=False, - default_value=default_run_name) - loggers_cfg: Dict[str, Any] = pop_config(cfg, - 'loggers', - must_exist=False, - default_value={}) - eval_subset_num_batches: int = pop_config(cfg, - 'eval_subset_num_batches', - must_exist=False, - default_value=-1) - icl_subset_num_batches: Optional[int] = pop_config(cfg, - 'icl_subset_num_batches', - must_exist=False, - default_value=None) - metadata: Optional[Dict[str, str]] = pop_config(cfg, - 'metadata', - must_exist=False, - default_value=None, - convert=True) - should_log_config: bool = pop_config(cfg, - 'log_config', - must_exist=False, - default_value=True) + run_name: str = pop_config( + cfg, + 'run_name', + must_exist=False, + default_value=default_run_name, + ) + loggers_cfg: Dict[ + str, + Any] = pop_config(cfg, 'loggers', must_exist=False, default_value={}) + eval_subset_num_batches: int = pop_config( + cfg, + 'eval_subset_num_batches', + must_exist=False, + default_value=-1, + ) + icl_subset_num_batches: Optional[int] = pop_config( + cfg, + 'icl_subset_num_batches', + must_exist=False, + default_value=None, + ) + metadata: Optional[Dict[str, str]] = pop_config( + cfg, + 'metadata', + must_exist=False, + default_value=None, + convert=True, + ) + should_log_config: bool = pop_config( + cfg, + 'log_config', + must_exist=False, + default_value=True, + ) # Pop out interpolation variables. pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None) - callback_configs: Optional[DictConfig] = pop_config(cfg, - 'callbacks', - must_exist=False, - default_value=None) + callback_configs: Optional[DictConfig] = pop_config( + cfg, + 'callbacks', + must_exist=False, + default_value=None, + ) # Warn for unused parameters for key in cfg: warnings.warn( - f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.' + f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.', ) reproducibility.seed_all(seed) @@ -265,7 +306,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: # Example of format string # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', ) logging.getLogger('llmfoundry').setLevel(python_log_level.upper()) @@ -288,8 +329,12 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: # mosaicml_logger will be None if the run isn't from the MosaicML platform if mosaicml_logger is not None: - log_eval_analytics(mosaicml_logger, model_configs, icl_tasks, - eval_gauntlet_config) + log_eval_analytics( + mosaicml_logger, + model_configs, + icl_tasks, + eval_gauntlet_config, + ) for model_cfg in model_configs: (trainer, logger_keys, eval_gauntlet_callback, @@ -313,12 +358,15 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: icl_subset_num_batches=icl_subset_num_batches, metadata=metadata, logged_config=logged_cfg, - should_log_config=should_log_config) + should_log_config=should_log_config, + ) trainers.append(trainer) if eval_gauntlet_callback is not None: composite_scores = eval_gauntlet_callback.eval_after_all( - trainer.state, trainer.logger) + trainer.state, + trainer.logger, + ) benchmark_to_taxonomy = {} if eval_gauntlet_callback is not None: @@ -326,9 +374,12 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: for b in t.benchmarks: benchmark_to_taxonomy[b.name] = t.name - model_results = calculate_markdown_results(logger_keys, trainer, - benchmark_to_taxonomy, - model_cfg.model_name) + model_results = calculate_markdown_results( + logger_keys, + trainer, + benchmark_to_taxonomy, + model_cfg.model_name, + ) if models_df is None: models_df = model_results @@ -338,17 +389,23 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: if eval_gauntlet_df is not None and eval_gauntlet_callback is not None: assert composite_scores is not None row = {'model_name': model_cfg['model_name']} - row.update( - {k.split('/')[-1]: v for k, v in composite_scores.items()}) - eval_gauntlet_df = pd.concat( - [eval_gauntlet_df, pd.DataFrame([row])], ignore_index=True) + row.update({ + k.split('/')[-1]: v for k, v in composite_scores.items() + }) + eval_gauntlet_df = pd.concat([ + eval_gauntlet_df, + pd.DataFrame([row]), + ], + ignore_index=True) print(f'Printing gauntlet results for all models') print( eval_gauntlet_df.sort_values( list(eval_gauntlet_callback.averages.keys())[0], - ascending=False).to_markdown(index=False)) + ascending=False, + ).to_markdown(index=False), + ) print(f'Printing complete results for all models') assert models_df is not None print(models_df.to_markdown(index=False)) @@ -358,9 +415,12 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: return trainers, eval_gauntlet_df -def calculate_markdown_results(logger_keys: List[str], trainer: Trainer, - benchmark_to_taxonomy: Dict[str, str], - model_name: str): +def calculate_markdown_results( + logger_keys: List[str], + trainer: Trainer, + benchmark_to_taxonomy: Dict[str, str], + model_name: str, +): results = {} for key in logger_keys: @@ -386,13 +446,19 @@ def calculate_markdown_results(logger_keys: List[str], trainer: Trainer, results[dl_name[1]][dl_name[0]][metric_name].append({ 'val': metric.compute(), - 'subcat': dl_name[-1] if len(dl_name) == 3 else 'no_subcat' + 'subcat': dl_name[-1] if len(dl_name) == 3 else 'no_subcat', }) - df = pd.DataFrame(columns=[ - 'Category', 'Benchmark', 'Subtask', 'Accuracy', 'Number few shot', - 'Model' - ]) + df = pd.DataFrame( + columns=[ + 'Category', + 'Benchmark', + 'Subtask', + 'Accuracy', + 'Number few shot', + 'Model', + ], + ) for num_shot in results: for benchmark in results[num_shot]: @@ -405,7 +471,7 @@ def calculate_markdown_results(logger_keys: List[str], trainer: Trainer, 'Subtask': None, 'Accuracy': subscores[0]['val'], 'Number few shot': num_shot, - 'Model': model_name + 'Model': model_name, } df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) else: @@ -421,7 +487,7 @@ def calculate_markdown_results(logger_keys: List[str], trainer: Trainer, 'Number few shot': num_shot, 'Model': - model_name + model_name, } df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) for sub in subscores: @@ -437,7 +503,7 @@ def calculate_markdown_results(logger_keys: List[str], trainer: Trainer, 'Number few shot': num_shot, 'Model': - model_name + model_name, } df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) diff --git a/scripts/inference/benchmarking/benchmark.py b/scripts/inference/benchmarking/benchmark.py index 00daf6b559..463183442d 100644 --- a/scripts/inference/benchmarking/benchmark.py +++ b/scripts/inference/benchmarking/benchmark.py @@ -22,14 +22,16 @@ def get_dtype(dtype: str): else: raise NotImplementedError( f'dtype {dtype} is not supported. ' + - f'We only support fp32, fp16, and bf16 currently') + f'We only support fp32, fp16, and bf16 currently', + ) def compare_dtype(dtype: torch.dtype, param_dtype: torch.dtype): if dtype != param_dtype: raise ValueError( f'dtype type is: {dtype} but model dtype is: {param_dtype}. ' + - f"The expected dtype and model dtype don't match.") + f"The expected dtype and model dtype don't match.", + ) def main(config: DictConfig): @@ -54,7 +56,7 @@ def main(config: DictConfig): 'replace_method': 'auto', 'enable_cuda_graph': False, 'tensor_parallel': { - 'tp_size': 0 + 'tp_size': 0, }, } @@ -87,17 +89,18 @@ def main(config: DictConfig): print('n_params is: ', n_params) print( - 'name, latency (s), throughput (tokens/s), latency_per_sequence_output_token (ms)' + 'name, latency (s), throughput (tokens/s), latency_per_sequence_output_token (ms)', ) print('=' * 75) for batch_size in config.batch_sizes: for input_length in config.input_lengths: for output_length in config.output_lengths: - batch = torch.randint(0, - config.model.vocab_size - 1, - size=(batch_size, - input_length)).to(device) + batch = torch.randint( + 0, + config.model.vocab_size - 1, + size=(batch_size, input_length), + ).to(device) # We're just going to have generate eos, padding tokens be # ignored by HF generate @@ -111,12 +114,14 @@ def main(config: DictConfig): start_time = time.time() with torch.no_grad(): with autocast_context: - model.generate(batch, - max_new_tokens=output_length, - use_cache=config.use_cache, - attention_mask=attention_mask, - eos_token_id=None, - pad_token_id=None) + model.generate( + batch, + max_new_tokens=output_length, + use_cache=config.use_cache, + attention_mask=attention_mask, + eos_token_id=None, + pad_token_id=None, + ) torch.cuda.synchronize() mean_time = (time.time() - start_time) / config.num_batches @@ -127,7 +132,7 @@ def main(config: DictConfig): run_name = f'{config.benchmark_name}_{batch_size}_{input_length}_{output_length}' print( - f'{run_name}, {mean_time:.3f}, {tokens_per_second:.3f}, {ms_per_seq_output_token:.3f}' + f'{run_name}, {mean_time:.3f}, {tokens_per_second:.3f}, {ms_per_seq_output_token:.3f}', ) diff --git a/scripts/inference/convert_composer_mpt_to_ft.py b/scripts/inference/convert_composer_mpt_to_ft.py index f59eb6005a..bea5b6715e 100644 --- a/scripts/inference/convert_composer_mpt_to_ft.py +++ b/scripts/inference/convert_composer_mpt_to_ft.py @@ -14,16 +14,20 @@ from composer.utils import get_file, safe_torch_load from transformers import PreTrainedTokenizer -from llmfoundry.utils import (convert_and_save_ft_weights, - get_hf_tokenizer_from_composer_state_dict) +from llmfoundry.utils import ( + convert_and_save_ft_weights, + get_hf_tokenizer_from_composer_state_dict, +) -def save_ft_config(composer_config: Dict[str, Any], - tokenizer: PreTrainedTokenizer, - save_dir: str, - infer_gpu_num: int = 1, - weight_data_type: str = 'fp32', - force: bool = False): +def save_ft_config( + composer_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + save_dir: str, + infer_gpu_num: int = 1, + weight_data_type: str = 'fp32', + force: bool = False, +): config = configparser.ConfigParser() config['gpt'] = {} @@ -31,8 +35,9 @@ def save_ft_config(composer_config: Dict[str, Any], config['gpt']['model_name'] = 'mpt' config['gpt']['head_num'] = str(composer_config['n_heads']) n_embd = composer_config['d_model'] - config['gpt']['size_per_head'] = str(n_embd // - composer_config['n_heads']) + config['gpt']['size_per_head'] = str( + n_embd // composer_config['n_heads'], + ) config['gpt']['inter_size'] = str(n_embd * composer_config['mlp_ratio']) config['gpt']['max_pos_seq_len'] = str(composer_config['max_seq_len']) config['gpt']['num_layer'] = str(composer_config['n_layers']) @@ -48,11 +53,11 @@ def save_ft_config(composer_config: Dict[str, Any], config['gpt']['use_attention_linear_bias'] = str(True) if composer_config['attn_clip_qkv'] and not force: raise RuntimeError( - 'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.' + 'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.', ) if composer_config['attn_qk_ln'] and not force: raise RuntimeError( - 'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.' + 'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.', ) with open(os.path.join(save_dir, 'config.ini'), 'w') as configfile: @@ -64,13 +69,13 @@ def save_ft_config(composer_config: Dict[str, Any], def write_ft_checkpoint_from_composer_checkpoint( - checkpoint_path: Union[Path, str], - infer_gpu_num: int, - save_dir: str, - trust_remote_code: bool, - output_precision: str = 'fp32', - local_checkpoint_save_location: Optional[Union[Path, - str]] = None) -> None: + checkpoint_path: Union[Path, str], + infer_gpu_num: int, + save_dir: str, + trust_remote_code: bool, + output_precision: str = 'fp32', + local_checkpoint_save_location: Optional[Union[Path, str]] = None, +) -> None: """Convert a Composer checkpoint to a FasterTransformer checkpoint folder. .. note:: This function may not work properly if you used surgery algorithms when you trained your model. In that case you may need to @@ -96,11 +101,12 @@ def write_ft_checkpoint_from_composer_checkpoint( if local_checkpoint_save_location is None: tmp_dir = tempfile.TemporaryDirectory() local_checkpoint_save_location = Path( - tmp_dir.name) / 'local-composer-checkpoint.pt' + tmp_dir.name, + ) / 'local-composer-checkpoint.pt' # download the checkpoint file print( - f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}' + f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}', ) get_file(str(checkpoint_path), str(local_checkpoint_save_location)) @@ -112,13 +118,13 @@ def write_ft_checkpoint_from_composer_checkpoint( # Extract Composer config from state dict if 'state' not in composer_state_dict: raise RuntimeError( - f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?' + f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?', ) if 'integrations' not in composer_state_dict[ - 'state'] or 'huggingface' not in composer_state_dict['state'][ - 'integrations']: + 'state'] or 'huggingface' not in composer_state_dict['state'][ + 'integrations']: raise RuntimeError( - 'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!' + 'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!', ) composer_config = composer_state_dict['state']['integrations'][ 'huggingface']['model']['config']['content'] @@ -127,14 +133,18 @@ def write_ft_checkpoint_from_composer_checkpoint( print('#' * 30) print('Extracting HF Tokenizer...') hf_tokenizer = get_hf_tokenizer_from_composer_state_dict( - composer_state_dict, trust_remote_code) + composer_state_dict, + trust_remote_code, + ) if hf_tokenizer is None: print('Warning! No HF Tokenizer found!') # Extract the model weights weights_state_dict = composer_state_dict['state']['model'] torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( - weights_state_dict, prefix='model.') + weights_state_dict, + prefix='model.', + ) # Converting weights to desired dtype for k, v in weights_state_dict.items(): @@ -144,21 +154,25 @@ def write_ft_checkpoint_from_composer_checkpoint( # Convert the weights using the config and tokenizer to FasterTransformer format print('#' * 30) print('Saving FasterTransformer config...') - save_ft_config(composer_config, - tokenizer=hf_tokenizer, - save_dir=save_dir, - weight_data_type=output_precision) + save_ft_config( + composer_config, + tokenizer=hf_tokenizer, + save_dir=save_dir, + weight_data_type=output_precision, + ) print('#' * 30) print('Converting weights to FasterTransformer format...') - convert_and_save_ft_weights(named_params=weights_state_dict, - config=composer_config, - infer_gpu_num=infer_gpu_num, - weight_data_type=output_precision, - save_dir=save_dir) + convert_and_save_ft_weights( + named_params=weights_state_dict, + config=composer_config, + infer_gpu_num=infer_gpu_num, + weight_data_type=output_precision, + save_dir=save_dir, + ) print('#' * 30) print( - f'FasterTransformer checkpoint folder successfully created at {save_dir}.' + f'FasterTransformer checkpoint folder successfully created at {save_dir}.', ) print('Done.') @@ -169,37 +183,42 @@ def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( description= - 'Convert an MPT Composer checkpoint into a standard FasterTransformer checkpoint folder.' + 'Convert an MPT Composer checkpoint into a standard FasterTransformer checkpoint folder.', ) parser.add_argument( '--composer_path', '-i', type=str, help='Composer checkpoint path. Can be a local file path or cloud URI', - required=True) + required=True, + ) parser.add_argument( '--local_checkpoint_save_location', type=str, help='If specified, where to save the checkpoint file to locally. \ If the input ``checkpoint_path`` is already a local path, this will be a symlink. \ Defaults to None, which will use a temporary file.', - default=None) + default=None, + ) parser.add_argument( '--ft_save_dir', '-o', type=str, help='Directory to save FasterTransformer converted checkpoint in', - required=True) - parser.add_argument('--infer_gpu_num', - '-i_g', - type=int, - help='How many gpus for inference?', - required=True) + required=True, + ) + parser.add_argument( + '--infer_gpu_num', + '-i_g', + type=int, + help='How many gpus for inference?', + required=True, + ) parser.add_argument( '--force', action='store_true', help= - 'Force conversion to FT even if some features may not work as expected in FT' + 'Force conversion to FT even if some features may not work as expected in FT', ) parser.add_argument( '--output_precision', @@ -207,11 +226,13 @@ def parse_args() -> Namespace: help= 'Data type of weights in the FasterTransformer output model. Input checkpoint weights will be converted to this dtype.', choices=['fp32', 'fp16'], - default='fp32') + default='fp32', + ) parser.add_argument( '--trust_remote_code', action='store_true', - help='Whether or not to use code outside of transformers module.') + help='Whether or not to use code outside of transformers module.', + ) return parser.parse_args() @@ -236,4 +257,5 @@ def parse_args() -> Namespace: save_dir=save_dir, output_precision=args.output_precision, local_checkpoint_save_location=args.local_checkpoint_save_location, - trust_remote_code=args.trust_remote_code) + trust_remote_code=args.trust_remote_code, + ) diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py index 51afb105c8..4d4019208c 100644 --- a/scripts/inference/convert_composer_to_hf.py +++ b/scripts/inference/convert_composer_to_hf.py @@ -10,8 +10,12 @@ import torch import transformers from composer.models.huggingface import get_hf_config_from_composer_state_dict -from composer.utils import (get_file, maybe_create_object_store_from_uri, - parse_uri, safe_torch_load) +from composer.utils import ( + get_file, + maybe_create_object_store_from_uri, + parse_uri, + safe_torch_load, +) from transformers import PretrainedConfig, PreTrainedTokenizerBase from llmfoundry import MPTConfig, MPTForCausalLM @@ -26,7 +30,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( output_path: Union[Path, str], trust_remote_code: bool, output_precision: str = 'fp32', - local_checkpoint_save_location: Optional[Union[Path, str]] = None + local_checkpoint_save_location: Optional[Union[Path, str]] = None, ) -> Tuple[PretrainedConfig, Optional[PreTrainedTokenizerBase]]: """Convert a Composer checkpoint to a pretrained HF checkpoint folder. @@ -81,14 +85,15 @@ def write_huggingface_pretrained_from_composer_checkpoint( if local_checkpoint_save_location is None: tmp_dir = tempfile.TemporaryDirectory() local_checkpoint_save_location = Path( - tmp_dir.name) / 'local-composer-checkpoint.pt' + tmp_dir.name, + ) / 'local-composer-checkpoint.pt' # create folder os.makedirs(output_path) # download the checkpoint file print( - f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}' + f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}', ) get_file(str(checkpoint_path), str(local_checkpoint_save_location)) @@ -98,7 +103,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( if 'state' not in composer_state_dict: raise RuntimeError( - f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?' + f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?', ) # Build and save HF Config @@ -113,7 +118,9 @@ def write_huggingface_pretrained_from_composer_checkpoint( print('#' * 30) print('Saving HF Tokenizer...') hf_tokenizer = get_hf_tokenizer_from_composer_state_dict( - composer_state_dict, trust_remote_code) + composer_state_dict, + trust_remote_code, + ) if hf_tokenizer is not None: hf_tokenizer.save_pretrained(output_path) print(hf_tokenizer) @@ -127,7 +134,9 @@ def write_huggingface_pretrained_from_composer_checkpoint( if 'state' in weights_state_dict: weights_state_dict = weights_state_dict['state']['model'] torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( - weights_state_dict, prefix='model.') + weights_state_dict, + prefix='model.', + ) # Convert weights to desired dtype for k, v in weights_state_dict.items(): @@ -147,23 +156,28 @@ def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( description= - 'Convert a HuggingFace causal LM in a Composer checkpoint into a standard HuggingFace checkpoint folder, and optionally upload to the hub.' + 'Convert a HuggingFace causal LM in a Composer checkpoint into a standard HuggingFace checkpoint folder, and optionally upload to the hub.', ) parser.add_argument('--composer_path', type=str, required=True) parser.add_argument('--hf_output_path', type=str, required=True) - parser.add_argument('--local_checkpoint_save_location', - type=str, - default=None) - parser.add_argument('--output_precision', - type=str, - choices=['fp32', 'fp16', 'bf16'], - default='fp32') + parser.add_argument( + '--local_checkpoint_save_location', + type=str, + default=None, + ) + parser.add_argument( + '--output_precision', + type=str, + choices=['fp32', 'fp16', 'bf16'], + default='fp32', + ) parser.add_argument('--hf_repo_for_upload', type=str, default=None) parser.add_argument('--test_uploaded_model', action='store_true') parser.add_argument( '--trust_remote_code', action='store_true', - help='Whether or not to use code outside of transformers module.') + help='Whether or not to use code outside of transformers module.', + ) return parser.parse_args() @@ -180,7 +194,8 @@ def _convert_composer_to_hf(args: Namespace) -> None: output_path=local_folder_path, trust_remote_code=args.trust_remote_code, output_precision=args.output_precision, - local_checkpoint_save_location=args.local_checkpoint_save_location) + local_checkpoint_save_location=args.local_checkpoint_save_location, + ) dtype = { 'fp32': torch.float32, @@ -194,12 +209,17 @@ def _convert_composer_to_hf(args: Namespace) -> None: config.init_device = 'cpu' if config.model_type == 'mpt': - loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path, - config=config, - torch_dtype=dtype) + loaded_hf_model = MPTForCausalLM.from_pretrained( + local_folder_path, + config=config, + torch_dtype=dtype, + ) else: loaded_hf_model = transformers.AutoModelForCausalLM.from_pretrained( - local_folder_path, config=config, torch_dtype=dtype) + local_folder_path, + config=config, + torch_dtype=dtype, + ) delattr(loaded_hf_model.config, '_name_or_path') @@ -207,8 +227,10 @@ def _convert_composer_to_hf(args: Namespace) -> None: print(f'Loading tokenizer from {local_folder_path}') - tokenizer = load_tokenizer(local_folder_path, - trust_remote_code=args.trust_remote_code) + tokenizer = load_tokenizer( + local_folder_path, + trust_remote_code=args.trust_remote_code, + ) tokenizer.save_pretrained(local_folder_path) # Only need to edit files for MPT because it has custom code @@ -220,7 +242,7 @@ def _convert_composer_to_hf(args: Namespace) -> None: if object_store is not None: print( - f'Uploading HF checkpoint folder from {local_folder_path} -> {args.hf_output_path}' + f'Uploading HF checkpoint folder from {local_folder_path} -> {args.hf_output_path}', ) for file in os.listdir(local_folder_path): remote_file = os.path.join(local_folder_path, file) @@ -232,27 +254,32 @@ def _convert_composer_to_hf(args: Namespace) -> None: api = HfApi() print( - f'Uploading {args.hf_output_path} to HuggingFace Hub at {args.hf_repo_for_upload}' + f'Uploading {args.hf_output_path} to HuggingFace Hub at {args.hf_repo_for_upload}', + ) + api.create_repo( + repo_id=args.hf_repo_for_upload, + use_auth_token=True, + repo_type='model', + private=True, + exist_ok=True, ) - api.create_repo(repo_id=args.hf_repo_for_upload, - use_auth_token=True, - repo_type='model', - private=True, - exist_ok=True) print('Repo created.') # ignore the full checkpoint file if we now have sharded checkpoint files ignore_patterns = [] if any( - f.startswith('pytorch_model-00001') - for f in os.listdir(args.hf_output_path)): + f.startswith('pytorch_model-00001') + for f in os.listdir(args.hf_output_path) + ): ignore_patterns.append('pytorch_model.bin') - api.upload_folder(folder_path=args.hf_output_path, - repo_id=args.hf_repo_for_upload, - use_auth_token=True, - repo_type='model', - ignore_patterns=ignore_patterns) + api.upload_folder( + folder_path=args.hf_output_path, + repo_id=args.hf_repo_for_upload, + use_auth_token=True, + repo_type='model', + ignore_patterns=ignore_patterns, + ) print('Folder uploaded.') if args.test_uploaded_model: @@ -261,30 +288,38 @@ def _convert_composer_to_hf(args: Namespace) -> None: args.hf_repo_for_upload, trust_remote_code=True, use_auth_token=True, - torch_dtype=dtype) + torch_dtype=dtype, + ) hub_tokenizer = transformers.AutoTokenizer.from_pretrained( args.hf_repo_for_upload, trust_remote_code=True, - use_auth_token=True) + use_auth_token=True, + ) - assert sum(p.numel() for p in hub_model.parameters()) == sum( - p.numel() for p in loaded_hf_model.parameters()) + assert sum(p.numel() for p in hub_model.parameters() + ) == sum(p.numel() for p in loaded_hf_model.parameters()) assert all( - str(type(module1)).split('.')[-2:] == str(type(module2)).split( - '.')[-2:] for module1, module2 in zip( - hub_model.modules(), loaded_hf_model.modules())) + str(type(module1)).split('.')[-2:] == str( + type(module2), + ).split('.')[-2:] for module1, module2 in + zip(hub_model.modules(), loaded_hf_model.modules()) + ) assert next( - hub_model.parameters() + hub_model.parameters(), ).dtype == dtype, f'Expected model dtype to be {dtype}, but got {next(hub_model.parameters()).dtype}' print( hub_tokenizer.batch_decode( - hub_model.generate(hub_tokenizer( - 'MosaicML is', return_tensors='pt').input_ids, - max_new_tokens=10))) + hub_model.generate( + hub_tokenizer('MosaicML is', + return_tensors='pt').input_ids, + max_new_tokens=10, + ), + ), + ) print( - 'Composer checkpoint successfully converted to HuggingFace checkpoint format.' + 'Composer checkpoint successfully converted to HuggingFace checkpoint format.', ) diff --git a/scripts/inference/convert_hf_mpt_to_ft.py b/scripts/inference/convert_hf_mpt_to_ft.py index 104d0d6b15..ada9e48f9e 100644 --- a/scripts/inference/convert_hf_mpt_to_ft.py +++ b/scripts/inference/convert_hf_mpt_to_ft.py @@ -30,11 +30,13 @@ from llmfoundry.utils import convert_and_save_ft_weights -def convert_mpt_to_ft(model_name_or_path: str, - output_dir: str, - infer_gpu_num: int = 1, - weight_data_type: str = 'fp32', - force: bool = False) -> None: +def convert_mpt_to_ft( + model_name_or_path: str, + output_dir: str, + infer_gpu_num: int = 1, + weight_data_type: str = 'fp32', + force: bool = False, +) -> None: """Convert an MPT checkpoint to a FasterTransformer compatible format. Args: @@ -56,9 +58,13 @@ def convert_mpt_to_ft(model_name_or_path: str, torch_device = 'cpu' model = transformers.AutoModelForCausalLM.from_pretrained( - model_name_or_path, trust_remote_code=True).to(torch_device) + model_name_or_path, + trust_remote_code=True, + ).to(torch_device) tokenizer = transformers.AutoTokenizer.from_pretrained( - model_name_or_path, trust_remote_code=True) + model_name_or_path, + trust_remote_code=True, + ) hf_config = vars(model.config) @@ -75,10 +81,10 @@ def convert_mpt_to_ft(model_name_or_path: str, config['gpt']['num_layer'] = str(hf_config['n_layers']) config['gpt']['vocab_size'] = str(hf_config['vocab_size']) config['gpt']['start_id'] = str( - hf_config['bos_token_id'] + hf_config['bos_token_id'], ) if hf_config['bos_token_id'] != None else str(tokenizer.bos_token_id) config['gpt']['end_id'] = str( - hf_config['eos_token_id'] + hf_config['eos_token_id'], ) if hf_config['eos_token_id'] != None else str(tokenizer.eos_token_id) config['gpt']['weight_data_type'] = weight_data_type config['gpt']['tensor_para_size'] = str(infer_gpu_num) @@ -89,11 +95,11 @@ def convert_mpt_to_ft(model_name_or_path: str, config['gpt']['use_attention_linear_bias'] = str(True) if hf_config['attn_config']['clip_qkv'] and not force: raise RuntimeError( - 'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.' + 'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.', ) if hf_config['attn_config']['qk_ln'] and not force: raise RuntimeError( - 'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.' + 'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.', ) with open(os.path.join(save_dir, 'config.ini'), 'w') as configfile: @@ -102,47 +108,55 @@ def convert_mpt_to_ft(model_name_or_path: str, print(f'Failed to save the config in config.ini.') raise - named_params_dict = { - name: param for name, param in model.named_parameters() - } - convert_and_save_ft_weights(named_params=named_params_dict, - config=hf_config, - infer_gpu_num=infer_gpu_num, - weight_data_type=weight_data_type, - save_dir=save_dir) + named_params_dict = dict(model.named_parameters()) + convert_and_save_ft_weights( + named_params=named_params_dict, + config=hf_config, + infer_gpu_num=infer_gpu_num, + weight_data_type=weight_data_type, + save_dir=save_dir, + ) if __name__ == '__main__': parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('--save_dir', - '-o', - type=str, - help='Directory to save converted checkpoint in', - required=True) + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + '--save_dir', + '-o', + type=str, + help='Directory to save converted checkpoint in', + required=True, + ) parser.add_argument( '--name_or_dir', '-i', type=str, help= 'HF hub Model name (e.g., mosaicml/mpt-7b) or local dir path to load checkpoint from', - required=True) - parser.add_argument('--infer_gpu_num', - '-i_g', - type=int, - help='How many gpus for inference?', - required=True) + required=True, + ) + parser.add_argument( + '--infer_gpu_num', + '-i_g', + type=int, + help='How many gpus for inference?', + required=True, + ) parser.add_argument( '--force', action='store_true', help= - 'Force conversion to FT even if some features may not work as expected in FT' + 'Force conversion to FT even if some features may not work as expected in FT', + ) + parser.add_argument( + '--weight_data_type', + type=str, + help='Data type of weights in the input checkpoint', + default='fp32', + choices=['fp32', 'fp16'], ) - parser.add_argument('--weight_data_type', - type=str, - help='Data type of weights in the input checkpoint', - default='fp32', - choices=['fp32', 'fp16']) args = parser.parse_args() print('\n=============== Argument ===============') @@ -150,5 +164,10 @@ def convert_mpt_to_ft(model_name_or_path: str, print('{}: {}'.format(key, vars(args)[key])) print('========================================') - convert_mpt_to_ft(args.name_or_dir, args.save_dir, args.infer_gpu_num, - args.weight_data_type, args.force) + convert_mpt_to_ft( + args.name_or_dir, + args.save_dir, + args.infer_gpu_num, + args.weight_data_type, + args.force, + ) diff --git a/scripts/inference/convert_hf_to_onnx.py b/scripts/inference/convert_hf_to_onnx.py index 9d1841b12f..f230e56bad 100644 --- a/scripts/inference/convert_hf_to_onnx.py +++ b/scripts/inference/convert_hf_to_onnx.py @@ -33,8 +33,11 @@ from typing import Any, Dict, Optional, Union import torch -from composer.utils import (maybe_create_object_store_from_uri, parse_uri, - reproducibility) +from composer.utils import ( + maybe_create_object_store_from_uri, + parse_uri, + reproducibility, +) from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer @@ -71,7 +74,7 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int): dtype=torch.int64, ), 'attention_mask': - torch.ones(size=(batch_size, max_seq_len), dtype=torch.bool) + torch.ones(size=(batch_size, max_seq_len), dtype=torch.bool), } return batch @@ -89,23 +92,29 @@ def export_to_onnx( _, _, parsed_save_path = parse_uri(output_folder) print('Loading HF config/model/tokenizer...') - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, - **from_pretrained_kwargs) - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, - **from_pretrained_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + **from_pretrained_kwargs, + ) + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + **from_pretrained_kwargs, + ) # specifically for MPT, switch to the torch version of attention for ONNX export if hasattr(config, 'attn_config'): config.attn_config['attn_impl'] = 'torch' - model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, - config=config, - **from_pretrained_kwargs) + model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + config=config, + **from_pretrained_kwargs, + ) model.eval() if max_seq_len is None and not hasattr(model.config, 'max_seq_len'): raise ValueError( - 'max_seq_len must be specified in either the model config or as an argument to this function.' + 'max_seq_len must be specified in either the model config or as an argument to this function.', ) elif max_seq_len is None: max_seq_len = model.config.max_seq_len @@ -195,16 +204,20 @@ def parse_args(): '--verify_export', action='store_true', ) - parser.add_argument('--trust_remote_code', - type=str2bool, - nargs='?', - const=True, - default=True) - parser.add_argument('--use_auth_token', - type=str_or_bool, - nargs='?', - const=True, - default=None) + parser.add_argument( + '--trust_remote_code', + type=str2bool, + nargs='?', + const=True, + default=True, + ) + parser.add_argument( + '--use_auth_token', + type=str_or_bool, + nargs='?', + const=True, + default=None, + ) parser.add_argument('--revision', type=str, default=None) return parser.parse_args() @@ -222,7 +235,8 @@ def main(args: argparse.Namespace): export_batch_size=args.export_batch_size, max_seq_len=args.max_seq_len, verify_export=args.verify_export, - from_pretrained_kwargs=from_pretrained_kwargs) + from_pretrained_kwargs=from_pretrained_kwargs, + ) if __name__ == '__main__': diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py index e6f9ae1448..9991f5093f 100644 --- a/scripts/inference/endpoint_generate.py +++ b/scripts/inference/endpoint_generate.py @@ -17,8 +17,11 @@ import pandas as pd import requests -from composer.utils import (get_file, maybe_create_object_store_from_uri, - parse_uri) +from composer.utils import ( + get_file, + maybe_create_object_store_from_uri, + parse_uri, +) from llmfoundry.utils import prompt_files as utils @@ -34,7 +37,8 @@ def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( - description='Call prompts against a text completions endpoint') + description='Call prompts against a text completions endpoint', + ) ##### # Path Parameters @@ -43,18 +47,21 @@ def parse_args() -> Namespace: '--inputs', nargs='+', help=f'List of strings, local data files (starting with {utils.PROMPTFILE_PREFIX}),' +\ - ' and/or remote object stores' + ' and/or remote object stores', ) parser.add_argument( '--prompt-delimiter', default='\n', help= - 'Prompt delimiter for txt files. By default, a file is a single prompt') + 'Prompt delimiter for txt files. By default, a file is a single prompt', + ) - parser.add_argument('-o', - '--output-folder', - required=True, - help='Remote folder to save the output') + parser.add_argument( + '-o', + '--output-folder', + required=True, + help='Remote folder to save the output', + ) ##### # Generation Parameters @@ -62,12 +69,14 @@ def parse_args() -> Namespace: '--rate-limit', type=int, default=75, - help='Max number of calls to make to the endpoint in a second') + help='Max number of calls to make to the endpoint in a second', + ) parser.add_argument( '--batch-size', type=int, default=10, - help='Max number of calls to make to the endpoint in a single request') + help='Max number of calls to make to the endpoint in a single request', + ) ##### # Endpoint Parameters @@ -76,7 +85,7 @@ def parse_args() -> Namespace: '--endpoint', type=str, help= - f'OpenAI-compatible text completions endpoint to query on. If not set, will read from {ENDPOINT_URL_ENV}' + f'OpenAI-compatible text completions endpoint to query on. If not set, will read from {ENDPOINT_URL_ENV}', ) parser.add_argument('--max-tokens', type=int, default=100) @@ -100,13 +109,14 @@ async def main(args: Namespace) -> None: if args.batch_size > args.rate_limit: raise ValueError( - f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s' + f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s', ) url = args.endpoint if args.endpoint else os.environ.get(ENDPOINT_URL_ENV) if not url: raise ValueError( - f'URL must be provided via --endpoint or {ENDPOINT_URL_ENV}') + f'URL must be provided via --endpoint or {ENDPOINT_URL_ENV}', + ) log.info(f'Using endpoint {url}') @@ -141,12 +151,16 @@ async def main(args: Namespace) -> None: total_batches = math.ceil(len(prompt_strings) / args.batch_size) log.info( - f'Generating {len(prompt_strings)} prompts in {total_batches} batches') + f'Generating {len(prompt_strings)} prompts in {total_batches} batches', + ) @sleep_and_retry @limits(calls=total_batches, period=1) # type: ignore - async def generate(session: aiohttp.ClientSession, batch: int, - prompts: list): + async def generate( + session: aiohttp.ClientSession, + batch: int, + prompts: list, + ): data = copy.copy(param_data) data['prompt'] = prompts headers = {'Authorization': api_key, 'Content-Type': 'application/json'} @@ -157,7 +171,8 @@ async def generate(session: aiohttp.ClientSession, batch: int, response = await resp.json() except requests.JSONDecodeError: raise Exception( - f'Bad response: {resp.status} {resp.reason}') + f'Bad response: {resp.status} {resp.reason}', + ) else: raise Exception(f'Bad response: {resp.status} {resp.reason}') @@ -165,8 +180,10 @@ async def generate(session: aiohttp.ClientSession, batch: int, n_compl = response['usage']['completion_tokens'] n_prompt = response['usage']['prompt_tokens'] req_latency = (req_end - req_start) - log.info(f'Completed batch {batch}: {n_compl:,} completion' + - f' tokens using {n_prompt:,} prompt tokens in {req_latency}s') + log.info( + f'Completed batch {batch}: {n_compl:,} completion' + + f' tokens using {n_prompt:,} prompt tokens in {req_latency}s', + ) res = pd.DataFrame(columns=cols) @@ -183,8 +200,9 @@ async def generate(session: aiohttp.ClientSession, batch: int, tasks = [] for i in range(total_batches): - prompts = prompt_strings[i * args.batch_size:min( - (i + 1) * args.batch_size, len(prompt_strings))] + prompts = prompt_strings[ + i * args.batch_size:min((i + 1) * + args.batch_size, len(prompt_strings))] tasks.append(generate(session, batch, prompts)) batch += 1 @@ -205,7 +223,8 @@ async def generate(session: aiohttp.ClientSession, batch: int, res.to_csv(local_path, index=False) output_object_store = maybe_create_object_store_from_uri( - args.output_folder) + args.output_folder, + ) if output_object_store is not None: _, _, output_folder_prefix = parse_uri(args.output_folder) remote_path = os.path.join(output_folder_prefix, file) diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index ab89364e30..3657bbe1b0 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -8,9 +8,16 @@ from typing import Any, Dict, List, Optional, Union import torch -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - PreTrainedModel, PreTrainedTokenizerBase, - StoppingCriteria, StoppingCriteriaList, TextStreamer) +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, + StoppingCriteria, + StoppingCriteriaList, + TextStreamer, +) DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.' @@ -54,12 +61,14 @@ class Conversation: cli_instructions: The instructions to display to the user. """ - def __init__(self, - model: PreTrainedModel, - tokenizer: PreTrainedTokenizerBase, - generate_kwargs: Dict[str, Any], - system_prompt: str, - stop_tokens: Optional[List[str]] = None) -> None: + def __init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + generate_kwargs: Dict[str, Any], + system_prompt: str, + stop_tokens: Optional[List[str]] = None, + ) -> None: if stop_tokens is None: stop_tokens = ['<|endoftext|>', '<|im_end|>'] self.model = model @@ -69,21 +78,28 @@ def __init__(self, if len(stop_token_ids) != len(stop_tokens): warnings.warn( f'Not all stop tokens were found in the tokenizer vocabulary: {stop_tokens}\n' - + 'Generation may stop or continue unexpectedly.') + + 'Generation may stop or continue unexpectedly.', + ) class StopOnTokens(StoppingCriteria): - def __call__(self, input_ids: torch.LongTensor, - scores: torch.FloatTensor, **kwargs: Any) -> bool: + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs: Any, + ) -> bool: del kwargs # unused for stop_id in stop_token_ids: if input_ids[0][-1] == stop_id: return True return False - self.streamer = TextStreamer(tokenizer, - skip_prompt=True, - skip_special_tokens=True) + self.streamer = TextStreamer( + tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) self.generate_kwargs = { **generate_kwargs, 'stopping_criteria': @@ -104,9 +120,7 @@ def __call__(self, input_ids: torch.LongTensor, ) def _history_to_chat_conversation(self) -> List[Dict[str, str]]: - msg_history = [] - for chat_msg in self.history: - msg_history.append(chat_msg.to_dict()) + msg_history = [chat_msg.to_dict() for chat_msg in self.history] return msg_history def _history_as_formatted_str(self) -> str: @@ -124,7 +138,8 @@ def turn(self, user_inp: str) -> None: chat_conversation, tokenize=True, add_generation_prompt=True, - return_tensors='pt') + return_tensors='pt', + ) tokenized_chat = tokenized_chat.to(self.model.device) # also stream to stdout maybe_synchronize() @@ -135,8 +150,10 @@ def turn(self, user_inp: str) -> None: end = time.time() print(f'\nTook {end - start:.2f} seconds') new_tokens = output_ids[0, len(tokenized_chat[0]):] - assistant_response = self.tokenizer.decode(new_tokens, - skip_special_tokens=True) + assistant_response = self.tokenizer.decode( + new_tokens, + skip_special_tokens=True, + ) self.history.append(ChatMessage('assistant', assistant_response)) def __call__(self) -> None: @@ -179,7 +196,8 @@ def get_dtype(dtype: str): else: raise NotImplementedError( f'dtype {dtype} is not supported. ' + - 'We only support fp32, fp16, and bf16 currently') + 'We only support fp32, fp16, and bf16 currently', + ) def str2bool(v: Union[str, bool]): @@ -207,61 +225,79 @@ def str_or_bool(v: Union[str, bool]): def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( - description='Load a HF CausalLM Model and use it to generate text.') + description='Load a HF CausalLM Model and use it to generate text.', + ) parser.add_argument('-n', '--name_or_path', type=str, required=True) parser.add_argument('--max_new_tokens', type=int, default=512) parser.add_argument('--max_seq_len', type=int, default=None) parser.add_argument('--temperature', type=float, default=1.0) parser.add_argument('--top_k', type=int, default=50) parser.add_argument('--top_p', type=float, default=1.0) - parser.add_argument('--do_sample', - type=str2bool, - nargs='?', - const=True, - default=True) - parser.add_argument('--use_cache', - type=str2bool, - nargs='?', - const=True, - default=True) + parser.add_argument( + '--do_sample', + type=str2bool, + nargs='?', + const=True, + default=True, + ) + parser.add_argument( + '--use_cache', + type=str2bool, + nargs='?', + const=True, + default=True, + ) parser.add_argument('--eos_token_id', type=str, default=None) parser.add_argument('--pad_token_id', type=str, default=None) - parser.add_argument('--model_dtype', - type=str, - choices=['fp32', 'fp16', 'bf16'], - default=None) - parser.add_argument('--autocast_dtype', - type=str, - choices=['fp32', 'fp16', 'bf16'], - default=None) - parser.add_argument('--warmup', - type=str2bool, - nargs='?', - const=True, - default=True) - parser.add_argument('--trust_remote_code', - type=str2bool, - nargs='?', - const=True, - default=True) - parser.add_argument('--use_auth_token', - type=str_or_bool, - nargs='?', - const=True, - default=None) + parser.add_argument( + '--model_dtype', + type=str, + choices=['fp32', 'fp16', 'bf16'], + default=None, + ) + parser.add_argument( + '--autocast_dtype', + type=str, + choices=['fp32', 'fp16', 'bf16'], + default=None, + ) + parser.add_argument( + '--warmup', + type=str2bool, + nargs='?', + const=True, + default=True, + ) + parser.add_argument( + '--trust_remote_code', + type=str2bool, + nargs='?', + const=True, + default=True, + ) + parser.add_argument( + '--use_auth_token', + type=str_or_bool, + nargs='?', + const=True, + default=None, + ) parser.add_argument('--revision', type=str, default=None) parser.add_argument('--device', type=str, default=None) parser.add_argument('--device_map', type=str, default=None) parser.add_argument('--attn_impl', type=str, default=None) parser.add_argument('--seed', type=int, default=42) - parser.add_argument('--system_prompt', - type=str, - default=DEFAULT_SYSTEM_PROMPT) + parser.add_argument( + '--system_prompt', + type=str, + default=DEFAULT_SYSTEM_PROMPT, + ) parser.add_argument( '--stop_tokens', type=str, default='<|endoftext|> <|im_end|>', - help='A string of tokens to stop generation on; will be split on spaces.' + help= + 'A string of tokens to stop generation on; will be split on spaces.', ) return parser.parse_args() @@ -298,8 +334,10 @@ def main(args: Namespace) -> None: 'revision': args.revision, } try: - config = AutoConfig.from_pretrained(args.name_or_path, - **from_pretrained_kwargs) + config = AutoConfig.from_pretrained( + args.name_or_path, + **from_pretrained_kwargs, + ) if args.attn_impl is not None and hasattr(config, 'attn_config'): config.attn_config['attn_impl'] = args.attn_impl if hasattr(config, 'init_device') and device is not None: @@ -313,17 +351,19 @@ def main(args: Namespace) -> None: + 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' + - 'using your access token from https://huggingface.co/settings/tokens.' + 'using your access token from https://huggingface.co/settings/tokens.', ) from e # Load HF Model print(f'Loading HF model with dtype={model_dtype}...') try: - model = AutoModelForCausalLM.from_pretrained(args.name_or_path, - config=config, - torch_dtype=model_dtype, - device_map=device_map, - **from_pretrained_kwargs) + model = AutoModelForCausalLM.from_pretrained( + args.name_or_path, + config=config, + torch_dtype=model_dtype, + device_map=device_map, + **from_pretrained_kwargs, + ) model.eval() print(f'n_params={sum(p.numel() for p in model.parameters())}') if device is not None: @@ -336,15 +376,17 @@ def main(args: Namespace) -> None: + 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' + - 'using your access token from https://huggingface.co/settings/tokens.' + 'using your access token from https://huggingface.co/settings/tokens.', ) from e print('\nLoading HF tokenizer...') - tokenizer = AutoTokenizer.from_pretrained(args.name_or_path, - **from_pretrained_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + args.name_or_path, + **from_pretrained_kwargs, + ) if tokenizer.pad_token_id is None: warnings.warn( - 'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.' + 'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.', ) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = 'left' @@ -368,19 +410,21 @@ def main(args: Namespace) -> None: autocast_context = nullcontext() print('NOT using autocast...') - conversation = Conversation(model=model, - tokenizer=tokenizer, - system_prompt=args.system_prompt, - generate_kwargs=generate_kwargs, - stop_tokens=args.stop_tokens.split()) + conversation = Conversation( + model=model, + tokenizer=tokenizer, + system_prompt=args.system_prompt, + generate_kwargs=generate_kwargs, + stop_tokens=args.stop_tokens.split(), + ) # Warmup if args.warmup: print('Warming up...') with autocast_context: conversation.turn('Write a welcome message to the user.') - conversation.history = conversation.history[: - 1] # keep system prompt + conversation.history = conversation.history[:1 + ] # keep system prompt print('Starting conversation...') with autocast_context: diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 57193136ec..eab46d7a69 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -53,7 +53,8 @@ def str_or_bool(v: Union[str, bool]): def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( - description='Load a HF CausalLM Model and use it to generate text.') + description='Load a HF CausalLM Model and use it to generate text.', + ) parser.add_argument('-n', '--name_or_path', type=str, required=True) parser.add_argument( '-p', @@ -64,13 +65,14 @@ def parse_args() -> Namespace: 'This is an explanation of deep learning to a five year old. Deep learning is', ], help='List of generation prompts or list of delimited files. Use syntax ' +\ - '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.' + '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.', ) parser.add_argument( '--prompt-delimiter', default=None, help= - 'Prompt delimiter for txt files. By default, a file is a single prompt') + 'Prompt delimiter for txt files. By default, a file is a single prompt', + ) parser.add_argument('--max_seq_len', type=int, default=None) parser.add_argument('--max_new_tokens', type=int, default=100) parser.add_argument('--max_batch_size', type=int, default=None) @@ -79,51 +81,69 @@ def parse_args() -> Namespace: parser.add_argument('--temperature', type=float, nargs='+', default=[1.0]) parser.add_argument('--top_k', type=int, nargs='+', default=[50]) parser.add_argument('--top_p', type=float, nargs='+', default=[1.0]) - parser.add_argument('--repetition_penalty', - type=float, - nargs='+', - default=[1.0]) - parser.add_argument('--no_repeat_ngram_size', - type=int, - nargs='+', - default=[0]) + parser.add_argument( + '--repetition_penalty', + type=float, + nargs='+', + default=[1.0], + ) + parser.add_argument( + '--no_repeat_ngram_size', + type=int, + nargs='+', + default=[0], + ) ##### parser.add_argument('--seed', type=int, nargs='+', default=[42]) - parser.add_argument('--do_sample', - type=str2bool, - nargs='?', - const=True, - default=True) - parser.add_argument('--use_cache', - type=str2bool, - nargs='?', - const=True, - default=True) + parser.add_argument( + '--do_sample', + type=str2bool, + nargs='?', + const=True, + default=True, + ) + parser.add_argument( + '--use_cache', + type=str2bool, + nargs='?', + const=True, + default=True, + ) parser.add_argument('--eos_token_id', type=int, default=None) parser.add_argument('--pad_token_id', type=int, default=None) - parser.add_argument('--model_dtype', - type=str, - choices=['fp32', 'fp16', 'bf16'], - default=None) - parser.add_argument('--autocast_dtype', - type=str, - choices=['fp32', 'fp16', 'bf16'], - default=None) - parser.add_argument('--warmup', - type=str2bool, - nargs='?', - const=True, - default=True) - parser.add_argument('--trust_remote_code', - type=str2bool, - nargs='?', - const=True, - default=True) - parser.add_argument('--use_auth_token', - type=str_or_bool, - nargs='?', - const=True, - default=None) + parser.add_argument( + '--model_dtype', + type=str, + choices=['fp32', 'fp16', 'bf16'], + default=None, + ) + parser.add_argument( + '--autocast_dtype', + type=str, + choices=['fp32', 'fp16', 'bf16'], + default=None, + ) + parser.add_argument( + '--warmup', + type=str2bool, + nargs='?', + const=True, + default=True, + ) + parser.add_argument( + '--trust_remote_code', + type=str2bool, + nargs='?', + const=True, + default=True, + ) + parser.add_argument( + '--use_auth_token', + type=str_or_bool, + nargs='?', + const=True, + default=None, + ) parser.add_argument('--revision', type=str, default=None) parser.add_argument('--device', type=str, default=None) parser.add_argument('--device_map', type=str, default=None) @@ -166,8 +186,10 @@ def main(args: Namespace) -> None: 'revision': args.revision, } try: - config = AutoConfig.from_pretrained(args.name_or_path, - **from_pretrained_kwargs) + config = AutoConfig.from_pretrained( + args.name_or_path, + **from_pretrained_kwargs, + ) if hasattr(config, 'init_device') and device is not None: config.init_device = device if args.attn_impl is not None and hasattr(config, 'attn_config'): @@ -179,16 +201,18 @@ def main(args: Namespace) -> None: raise RuntimeError( 'If you are having auth problems, try logging in via `huggingface-cli login` ' +\ 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' +\ - 'using your access token from https://huggingface.co/settings/tokens.' + 'using your access token from https://huggingface.co/settings/tokens.', ) from e # Build tokenizer print('\nLoading HF tokenizer...') - tokenizer = AutoTokenizer.from_pretrained(args.name_or_path, - **from_pretrained_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + args.name_or_path, + **from_pretrained_kwargs, + ) if tokenizer.pad_token_id is None: warnings.warn( - 'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.' + 'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.', ) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = 'left' @@ -196,11 +220,13 @@ def main(args: Namespace) -> None: # Load HF Model print(f'Loading HF model with dtype={model_dtype}...') try: - model = AutoModelForCausalLM.from_pretrained(args.name_or_path, - config=config, - torch_dtype=model_dtype, - device_map=device_map, - **from_pretrained_kwargs) + model = AutoModelForCausalLM.from_pretrained( + args.name_or_path, + config=config, + torch_dtype=model_dtype, + device_map=device_map, + **from_pretrained_kwargs, + ) model.eval() print(f'n_params={sum(p.numel() for p in model.parameters())}') if device is not None: @@ -213,7 +239,7 @@ def main(args: Namespace) -> None: + 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' + - 'using your access token from https://huggingface.co/settings/tokens.' + 'using your access token from https://huggingface.co/settings/tokens.', ) from e # Autocast @@ -228,8 +254,13 @@ def main(args: Namespace) -> None: done_warmup = False for temp, topp, topk, repp, nrnz, seed in itertools.product( - args.temperature, args.top_p, args.top_k, args.repetition_penalty, - args.no_repeat_ngram_size, args.seed): + args.temperature, + args.top_p, + args.top_k, + args.repetition_penalty, + args.no_repeat_ngram_size, + args.seed, + ): # Seed randomness random.seed(seed) @@ -284,7 +315,8 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]): input_tokens = torch.sum( encoded_inp['input_ids'] != tokenizer.pad_token_id, # type: ignore - axis=1).numpy(force=True) + axis=1, + ).numpy(force=True) # Warmup if args.warmup and (not done_warmup): @@ -301,21 +333,28 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]): gen_end = time.time() decode_start = time.time() - decoded_gen = tokenizer.batch_decode(encoded_gen, - skip_special_tokens=True) + decoded_gen = tokenizer.batch_decode( + encoded_gen, + skip_special_tokens=True, + ) maybe_synchronize() decode_end = time.time() - gen_tokens = torch.sum(encoded_gen != tokenizer.pad_token_id, - axis=1).numpy(force=True) # type: ignore + gen_tokens = torch.sum( + encoded_gen != tokenizer.pad_token_id, + axis=1, + ).numpy(force=True) # type: ignore # Print generations delimiter = '#' * 100 # decode the encoded prompt to handle the case when the tokenizer # trims extra spaces or does other pre-tokenization things - effective_prompts = tokenizer.batch_decode(encoded_inp['input_ids'], - skip_special_tokens=True) + effective_prompts = tokenizer.batch_decode( + encoded_inp['input_ids'], + skip_special_tokens=True, + ) for idx, (effective_prompt, prompt, gen) in enumerate( - zip(effective_prompts, batch, decoded_gen)): + zip(effective_prompts, batch, decoded_gen), + ): continuation = gen[len(effective_prompt):] print(delimiter) if len(continuation) > 0: @@ -323,11 +362,13 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]): else: print('Warning. No non-special output tokens generated.') print( - 'This can happen if the generation only contains padding/eos tokens.' + 'This can happen if the generation only contains padding/eos tokens.', ) print('Debug:') full_generation = tokenizer.batch_decode( - encoded_gen, skip_special_tokens=False)[idx] + encoded_gen, + skip_special_tokens=False, + )[idx] print('\033[92m' + 'Prompt:\n' + prompt + '\033[0m') print('Full generation:\n' + full_generation) @@ -351,7 +392,7 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]): print(f'{bs=}, {input_tokens=}, {output_tokens=}') print(f'{total_input_tokens=}, {total_output_tokens=}') print( - f'{encode_latency=:.2f}ms, {gen_latency=:.2f}ms, {decode_latency=:.2f}ms, {total_latency=:.2f}ms' + f'{encode_latency=:.2f}ms, {gen_latency=:.2f}ms, {decode_latency=:.2f}ms, {total_latency=:.2f}ms', ) print(f'{latency_per_output_token=:.2f}ms/tok') print(f'{output_tok_per_sec=:.2f}tok/sec') diff --git a/scripts/inference/run_mpt_with_ft.py b/scripts/inference/run_mpt_with_ft.py index 61d9f68d2c..3361422466 100644 --- a/scripts/inference/run_mpt_with_ft.py +++ b/scripts/inference/run_mpt_with_ft.py @@ -41,141 +41,186 @@ @torch.no_grad() def main(): parser = argparse.ArgumentParser() - parser.add_argument('--layer_num', - type=int, - default=32, - help='number of layers') - parser.add_argument('--input_len', - type=int, - default=128, - help='input sequence length to generate.') - parser.add_argument('--output_len', - type=int, - default=64, - help='output sequence length to generate.') + parser.add_argument( + '--layer_num', + type=int, + default=32, + help='number of layers', + ) + parser.add_argument( + '--input_len', + type=int, + default=128, + help='input sequence length to generate.', + ) + parser.add_argument( + '--output_len', + type=int, + default=64, + help='output sequence length to generate.', + ) parser.add_argument('--head_num', type=int, default=32, help='head number') - parser.add_argument('--size_per_head', - type=int, - default=128, - help='size per head') - parser.add_argument('--vocab_size', - type=int, - default=50432, - help='vocab size') + parser.add_argument( + '--size_per_head', + type=int, + default=128, + help='size per head', + ) + parser.add_argument( + '--vocab_size', + type=int, + default=50432, + help='vocab size', + ) parser.add_argument( '--beam_width', type=int, default=1, - help='beam width for beam search. Using sampling when beam width is 1.') - parser.add_argument('--top_k', - type=int, - default=1, - help='top k candidate num') - parser.add_argument('--top_p', - type=float, - default=0.95, - help='top p probability threshold') - parser.add_argument('--temperature', - type=float, - default=0.8, - help='temperature') - parser.add_argument('--len_penalty', - type=float, - default=0., - help='len_penalty') - parser.add_argument('--beam_search_diversity_rate', - type=float, - default=0., - help='beam_search_diversity_rate') - parser.add_argument('--tensor_para_size', - type=int, - default=1, - help='tensor parallel size') - parser.add_argument('--pipeline_para_size', - type=int, - default=1, - help='pipeline parallel size') - parser.add_argument('--ckpt_path', - type=str, - default='mpt-ft-7b/1-gpu', - help='path to the FT checkpoint file.') + help='beam width for beam search. Using sampling when beam width is 1.', + ) + parser.add_argument( + '--top_k', + type=int, + default=1, + help='top k candidate num', + ) + parser.add_argument( + '--top_p', + type=float, + default=0.95, + help='top p probability threshold', + ) + parser.add_argument( + '--temperature', + type=float, + default=0.8, + help='temperature', + ) + parser.add_argument( + '--len_penalty', + type=float, + default=0., + help='len_penalty', + ) + parser.add_argument( + '--beam_search_diversity_rate', + type=float, + default=0., + help='beam_search_diversity_rate', + ) + parser.add_argument( + '--tensor_para_size', + type=int, + default=1, + help='tensor parallel size', + ) + parser.add_argument( + '--pipeline_para_size', + type=int, + default=1, + help='pipeline parallel size', + ) + parser.add_argument( + '--ckpt_path', + type=str, + default='mpt-ft-7b/1-gpu', + help='path to the FT checkpoint file.', + ) parser.add_argument( '--tokenizer_name_or_path', type=str, default='EleutherAI/gpt-neox-20b', help= - 'Name of the tokenizer or the directory where the tokenizer file is located.' + 'Name of the tokenizer or the directory where the tokenizer file is located.', ) parser.add_argument( '--lib_path', type=str, help= - 'path to the libth_transformer dynamic lib file(.e.g., build/lib/libth_transformer.so.' + 'path to the libth_transformer dynamic lib file(.e.g., build/lib/libth_transformer.so.', + ) + parser.add_argument( + '--start_id', + type=int, + default=0, + help='start token id.', ) - parser.add_argument('--start_id', - type=int, - default=0, - help='start token id.') parser.add_argument('--end_id', type=int, default=0, help='end token id.') parser.add_argument( '--max_batch_size', type=int, default=8, help= - 'Max batch size. If sample_input_file is given, it is truncated to this max_batch_size, otherwise, this value is used as batch size.' + 'Max batch size. If sample_input_file is given, it is truncated to this max_batch_size, otherwise, this value is used as batch size.', + ) + parser.add_argument( + '--repetition_penalty', + type=float, + default=5., + help='repetition penalty', ) - parser.add_argument('--repetition_penalty', - type=float, - default=5., - help='repetition penalty') parser.add_argument( '--presence_penalty', type=float, default=0., help= - 'presence penalty. Similar to repetition, but additive rather than multiplicative.' + 'presence penalty. Similar to repetition, but additive rather than multiplicative.', + ) + parser.add_argument( + '--min_length', + type=int, + default=0, + help='A minimum number of tokens to generate', ) - parser.add_argument('--min_length', - type=int, - default=0, - help='A minimum number of tokens to generate') parser.add_argument( '--max_seq_len', type=int, default=2048, - help='max sequence length for position embedding table.') - parser.add_argument('--inference_data_type', - '--data_type', - type=str, - choices=['fp32', 'fp16', 'bf16'], - default='bf16') - parser.add_argument('--time', - action='store_true', - help='whether or not to measure time elapsed.') + help='max sequence length for position embedding table.', + ) + parser.add_argument( + '--inference_data_type', + '--data_type', + type=str, + choices=['fp32', 'fp16', 'bf16'], + default='bf16', + ) + parser.add_argument( + '--time', + action='store_true', + help='whether or not to measure time elapsed.', + ) parser.add_argument( '--sample_input_file', type=str, default=None, help= - 'path to sample input file. If not set, it runs with no context inputs.' + 'path to sample input file. If not set, it runs with no context inputs.', + ) + parser.add_argument( + '--sample_output_file', + type=str, + default=None, + help='path to sample output file.', ) - parser.add_argument('--sample_output_file', - type=str, - default=None, - help='path to sample output file.') parser.add_argument( '--disable_random_seed', dest='random_seed', action='store_false', - help='Disable the use of random seed for sentences in a batch.') - parser.add_argument('--skip_end_tokens', - dest='skip_end_tokens', - action='store_false', - help='Whether to remove or not end tokens in outputs.') - parser.add_argument('--no_detokenize', - dest='detokenize', - action='store_false', - help='Skip detokenizing output token ids.') + help='Disable the use of random seed for sentences in a batch.', + ) + parser.add_argument( + '--skip_end_tokens', + dest='skip_end_tokens', + action='store_false', + help='Whether to remove or not end tokens in outputs.', + ) + parser.add_argument( + '--no_detokenize', + dest='detokenize', + action='store_false', + help='Skip detokenizing output token ids.', + ) parser.add_argument( '--int8_mode', type=int, @@ -183,7 +228,7 @@ def main(): choices=[0, 1], help='The level of quantization to perform.' + ' 0: No quantization. All computation in data_type' + - ' 1: Quantize weights to int8, all compute occurs in fp16/bf16. Not supported when data_type is fp32' + ' 1: Quantize weights to int8, all compute occurs in fp16/bf16. Not supported when data_type is fp32', ) parser.add_argument( '--weights_data_type', @@ -200,28 +245,33 @@ def main(): help='Whether to compute the cumulative log probability of sentences.' + ' 0: do not return the cumulative log probs' + ' 1: return the cumulative log probs of generated sequences' + - ' 2: return the cumulative log probs of sequences') - parser.add_argument('--shared_contexts_ratio', - type=float, - default=0.0, - help='Triggers the shared context optimization when ' + - 'compact_size <= shared_contexts_ratio * batch_size ' + - 'A value of 0.0 deactivate the optimization') + ' 2: return the cumulative log probs of sequences', + ) + parser.add_argument( + '--shared_contexts_ratio', + type=float, + default=0.0, + help='Triggers the shared context optimization when ' + + 'compact_size <= shared_contexts_ratio * batch_size ' + + 'A value of 0.0 deactivate the optimization', + ) parser.add_argument( '--use_gpt_decoder_ops', action='store_true', - help='Use separate decoder FT operators instead of end-to-end model op.' + help='Use separate decoder FT operators instead of end-to-end model op.', ) parser.add_argument( '--no-alibi', dest='alibi', action='store_false', - help='Do not use ALiBi (aka use_attention_linear_bias).') + help='Do not use ALiBi (aka use_attention_linear_bias).', + ) parser.add_argument( '--layernorm_eps', type=float, default=1e-5, - help='layernorm eps in PyTorch, by default, is 1e-5 and 1e-6 in FT.') + help='layernorm eps in PyTorch, by default, is 1e-5 and 1e-6 in FT.', + ) args = parser.parse_args() ckpt_config = configparser.ConfigParser() @@ -241,7 +291,8 @@ def main(): args.__dict__[args_key] = func('gpt', config_key) print( 'Loading {} from config.ini, previous: {}, current: {}' - .format(args_key, prev_val, args.__dict__[args_key])) + .format(args_key, prev_val, args.__dict__[args_key]), + ) else: print('Not loading {} from config.ini'.format(args_key)) for key in ['head_num', 'size_per_head', 'tensor_para_size']: @@ -250,7 +301,8 @@ def main(): args.__dict__[key] = ckpt_config.getint('gpt', key) print( 'Loading {} from config.ini, previous: {}, current: {}' - .format(key, prev_val, args.__dict__[key])) + .format(key, prev_val, args.__dict__[key]), + ) else: print('Not loading {} from config.ini'.format(key)) @@ -290,8 +342,10 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) torch.manual_seed(0) - comm.initialize_model_parallel(args.tensor_para_size, - args.pipeline_para_size) + comm.initialize_model_parallel( + args.tensor_para_size, + args.pipeline_para_size, + ) rank = comm.get_rank() device = comm.get_device() @@ -309,8 +363,9 @@ def main(): else: batch_size = max_batch_size contexts = ['<|endoftext|>'] * batch_size - start_ids = [torch.IntTensor([end_id for _ in range(args.input_len)]) - ] * batch_size + start_ids = [ + torch.IntTensor([end_id for _ in range(args.input_len)]), + ] * batch_size start_lengths = [len(ids) for ids in start_ids] @@ -319,54 +374,64 @@ def main(): # Prepare model. if not args.use_gpt_decoder_ops: - gpt = ParallelGPT(head_num, - size_per_head, - vocab_size, - start_id, - end_id, - layer_num, - max_seq_len, - tensor_para_size, - pipeline_para_size, - lib_path=args.lib_path, - inference_data_type=args.inference_data_type, - int8_mode=args.int8_mode, - weights_data_type=weights_data_type, - layernorm_eps=layernorm_eps, - use_attention_linear_bias=use_attention_linear_bias, - has_positional_encoding=has_positional_encoding, - shared_contexts_ratio=shared_contexts_ratio) + gpt = ParallelGPT( + head_num, + size_per_head, + vocab_size, + start_id, + end_id, + layer_num, + max_seq_len, + tensor_para_size, + pipeline_para_size, + lib_path=args.lib_path, + inference_data_type=args.inference_data_type, + int8_mode=args.int8_mode, + weights_data_type=weights_data_type, + layernorm_eps=layernorm_eps, + use_attention_linear_bias=use_attention_linear_bias, + has_positional_encoding=has_positional_encoding, + shared_contexts_ratio=shared_contexts_ratio, + ) if not gpt.load(ckpt_path=args.ckpt_path): print( - '[WARNING] Checkpoint file not found. Model loading is skipped.' + '[WARNING] Checkpoint file not found. Model loading is skipped.', ) else: - gpt = gpt_decoder.Gpt(num_heads=head_num, - size_per_head=size_per_head, - num_layers=layer_num, - vocab_size=vocab_size, - start_id=start_id, - end_id=end_id, - tensor_para_size=tensor_para_size, - pipeline_para_size=pipeline_para_size, - lib_path=args.lib_path, - max_seq_len=max_seq_len, - int8_mode=args.int8_mode, - weights_data_type=args.weights_data_type) + gpt = gpt_decoder.Gpt( + num_heads=head_num, + size_per_head=size_per_head, + num_layers=layer_num, + vocab_size=vocab_size, + start_id=start_id, + end_id=end_id, + tensor_para_size=tensor_para_size, + pipeline_para_size=pipeline_para_size, + lib_path=args.lib_path, + max_seq_len=max_seq_len, + int8_mode=args.int8_mode, + weights_data_type=args.weights_data_type, + ) gpt.load(args.ckpt_path, args.inference_data_type) if args.random_seed: - random_seed_tensor = torch.randint(0, - 10000, - size=[batch_size], - dtype=torch.int64) + random_seed_tensor = torch.randint( + 0, + 10000, + size=[batch_size], + dtype=torch.int64, + ) else: random_seed_tensor = torch.zeros([batch_size], dtype=torch.int64) repetition_penalty_vec = None if repetition_penalty == 1. else repetition_penalty * torch.ones( - batch_size, dtype=torch.float32) + batch_size, + dtype=torch.float32, + ) presence_penalty_vec = None if presence_penalty == 0. else presence_penalty * torch.ones( - batch_size, dtype=torch.float32) + batch_size, + dtype=torch.float32, + ) infer_decode_args = { 'beam_width': @@ -391,18 +456,20 @@ def main(): 'min_length': min_length * torch.ones(size=[batch_size], dtype=torch.int32), 'random_seed': - random_seed_tensor + random_seed_tensor, } if not args.use_gpt_decoder_ops: def gpt_generate_fn(): - tokens_batch = gpt(start_ids, - start_lengths, - output_len, - return_output_length=return_output_length, - return_cum_log_probs=return_cum_log_probs, - **infer_decode_args) + tokens_batch = gpt( + start_ids, + start_lengths, + output_len, + return_output_length=return_output_length, + return_cum_log_probs=return_cum_log_probs, + **infer_decode_args, + ) return tokens_batch else: @@ -414,7 +481,8 @@ def gpt_generate_fn(): eos_token_id=end_id, return_output_length=return_output_length, return_log_probs=return_cum_log_probs, - **infer_decode_args) + **infer_decode_args, + ) return output_dict # Generate tokens. @@ -428,8 +496,8 @@ def gpt_generate_fn(): tokens_batch, cum_log_probs = gen_outputs, None else: tokens_batch = gen_outputs['output_token_ids'] - cum_log_probs = gen_outputs[ - 'cum_log_probs'] if return_cum_log_probs > 0 else None + cum_log_probs = gen_outputs['cum_log_probs' + ] if return_cum_log_probs > 0 else None if cum_log_probs is not None: print('[INFO] Log probs of sentences:', cum_log_probs) @@ -442,11 +510,13 @@ def gpt_generate_fn(): if args.skip_end_tokens: token = token[token != end_id] output = tokenizer.decode( - token) if args.detokenize else ' '.join( - str(t) for t in token.tolist()) + token, + ) if args.detokenize else ' '.join( + str(t) for t in token.tolist() + ) outputs.append(output) print( - f'[INFO] batch {i}, beam {beam_id}:\n[Context]\n{context}\n\n[Output]\n{output}\n' + f'[INFO] batch {i}, beam {beam_id}:\n[Context]\n{context}\n\n[Output]\n{output}\n', ) if args.sample_output_file: @@ -469,10 +539,10 @@ def gpt_generate_fn(): if rank == 0: print(f'[INFO] MPT time costs:') print( - 'model_name, gpu_type, gpu_count, batch_size, input_tokens, output_tokens, latency_ms' + 'model_name, gpu_type, gpu_count, batch_size, input_tokens, output_tokens, latency_ms', ) print( - f'{ckpt_config.get("gpt", "model_name")}, {torch.cuda.get_device_name().replace(" ", "-")}, {torch.cuda.device_count()}, {batch_size}, {args.input_len}, {args.output_len}, {time_elapsed * 1000 / measurement_iterations:.2f}' + f'{ckpt_config.get("gpt", "model_name")}, {torch.cuda.get_device_name().replace(" ", "-")}, {torch.cuda.device_count()}, {batch_size}, {args.input_len}, {args.output_len}, {time_elapsed * 1000 / measurement_iterations:.2f}', ) diff --git a/scripts/misc/convert_examples_ckpt.py b/scripts/misc/convert_examples_ckpt.py index db1301674c..01da91b3f4 100644 --- a/scripts/misc/convert_examples_ckpt.py +++ b/scripts/misc/convert_examples_ckpt.py @@ -10,11 +10,17 @@ from typing import Any, Dict, Optional, Union import torch -from composer.utils import (get_file, maybe_create_object_store_from_uri, - parse_uri, safe_torch_load) - -from llmfoundry.models.mpt.configuration_mpt import (attn_config_defaults, - init_config_defaults) +from composer.utils import ( + get_file, + maybe_create_object_store_from_uri, + parse_uri, + safe_torch_load, +) + +from llmfoundry.models.mpt.configuration_mpt import ( + attn_config_defaults, + init_config_defaults, +) # define state dict key changes # old_state_dict_key: new_state_dict_key @@ -120,13 +126,21 @@ def convert_examples_ckpt( hf_config['attn_config'] = deepcopy(attn_config_defaults) hf_config['attn_config']['attn_type'] = 'multihead_attention' hf_config['attn_config']['qk_ln'] = hf_config.pop( - 'attn_qk_ln', attn_config_defaults['qk_ln']) + 'attn_qk_ln', + attn_config_defaults['qk_ln'], + ) hf_config['attn_config']['clip_qkv'] = hf_config.pop( - 'attn_clip_qkv', attn_config_defaults['clip_qkv']) + 'attn_clip_qkv', + attn_config_defaults['clip_qkv'], + ) for k in [ - 'attn_pdrop', 'attn_impl', 'softmax_scale', - 'attn_uses_sequence_id', 'alibi', 'alibi_bias_max' + 'attn_pdrop', + 'attn_impl', + 'softmax_scale', + 'attn_uses_sequence_id', + 'alibi', + 'alibi_bias_max', ]: if k in hf_config: hf_config['attn_config'][k] = hf_config.pop(k) @@ -144,9 +158,13 @@ def convert_examples_ckpt( hf_config['init_config']['name'] = hf_config.pop('param_init_fn') for k in [ - 'fan_mode', 'init_nonlinearity', 'init_gain', 'init_std', - 'init_div_is_residual', 'emb_init_std', - 'emb_init_uniform_lim' + 'fan_mode', + 'init_nonlinearity', + 'init_gain', + 'init_std', + 'init_div_is_residual', + 'emb_init_std', + 'emb_init_uniform_lim', ]: if k in hf_config: hf_config['init_config'][k] = hf_config.pop(k) @@ -167,28 +185,37 @@ def convert_examples_ckpt( composer_state_dict['state']['optimizers'][opt]['state'] = opt_state for pg_idx in range( - len(composer_state_dict['state']['optimizers'][opt] - ['param_groups'])): + len( + composer_state_dict['state']['optimizers'][opt] + ['param_groups'], + ), + ): for param_idx in range( - len(composer_state_dict['state']['optimizers'][opt] - ['param_groups'][pg_idx]['params'])): + len( + composer_state_dict['state']['optimizers'][opt] + ['param_groups'][pg_idx]['params'], + ), + ): param_name = composer_state_dict['state']['optimizers'][ opt]['param_groups'][pg_idx]['params'][param_idx] for old, new in conversion_dict.items(): param_name = param_name.replace(old, new) composer_state_dict['state']['optimizers'][opt][ - 'param_groups'][pg_idx]['params'][ - param_idx] = param_name + 'param_groups'][pg_idx]['params'][param_idx + ] = param_name # Save weights file_path = str( - Path(local_output_path) / str(checkpoint_path).split('/')[-1]) + Path(local_output_path) / str(checkpoint_path).split('/')[-1], + ) print(f'Writing converted output to {file_path}') torch.save(composer_state_dict, file_path) if object_store is not None: - remote_file_path = os.path.join(local_folder_path, - str(checkpoint_path).split('/')[-1]) + remote_file_path = os.path.join( + local_folder_path, + str(checkpoint_path).split('/')[-1], + ) print(f'Uploading from {file_path} to {remote_file_path}') object_store.upload_object(remote_file_path, file_path) @@ -205,7 +232,7 @@ def main(args: Namespace) -> None: if __name__ == '__main__': parser = ArgumentParser( description= - 'Convert ckpt created with the examples repo into one usable by llmfoundry.' + 'Convert ckpt created with the examples repo into one usable by llmfoundry.', ) parser.add_argument('--checkpoint_path', type=str, required=True) parser.add_argument('--output_path', type=str, required=True) diff --git a/scripts/misc/download_model.py b/scripts/misc/download_model.py index 539c185b26..4e36c35e29 100644 --- a/scripts/misc/download_model.py +++ b/scripts/misc/download_model.py @@ -22,21 +22,28 @@ import os from llmfoundry.utils.model_download_utils import ( - download_from_hf_hub, download_from_http_fileserver, download_from_oras) + download_from_hf_hub, + download_from_http_fileserver, + download_from_oras, +) HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' -logging.basicConfig(format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', - level=logging.INFO) +logging.basicConfig( + format=f'%(asctime)s: %(levelname)s: %(name)s: %(message)s', + level=logging.INFO, +) log = logging.getLogger(__name__) def add_hf_parser_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument('--model', type=str, required=True) parser.add_argument('--prefer-safetensors', type=bool, default=True) - parser.add_argument('--token', - type=str, - default=os.getenv(HF_TOKEN_ENV_VAR)) + parser.add_argument( + '--token', + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR), + ) def add_oras_parser_arguments(parser: argparse.ArgumentParser) -> None: @@ -57,9 +64,11 @@ def parse_args() -> argparse.Namespace: base_parser = argparse.ArgumentParser(add_help=False) base_parser.add_argument('--save-dir', type=str, required=True) - base_parser.add_argument('--tokenizer-only', - default=False, - action='store_true') + base_parser.add_argument( + '--tokenizer-only', + default=False, + action='store_true', + ) # Add subparser for downloading from Hugging Face Hub. hf_parser = subparsers.add_parser('hf', parents=[base_parser]) @@ -91,11 +100,14 @@ def parse_args() -> argparse.Namespace: if download_from == 'http': if args.tokenizer_only: log.warning( - 'tokenizer-only is not currently supported for http. Downloading all files instead.' + 'tokenizer-only is not currently supported for http. Downloading all files instead.', ) try: - download_from_http_fileserver(args.url, args.save_dir, - args.ignore_cert) + download_from_http_fileserver( + args.url, + args.save_dir, + args.ignore_cert, + ) except PermissionError as e: log.error(f'Not authorized to download {args.model}.') raise e @@ -109,20 +121,25 @@ def parse_args() -> argparse.Namespace: download_from = 'oras' else: raise ValueError( - f'Invalid fallback destination {args.fallback}.') + f'Invalid fallback destination {args.fallback}.', + ) else: raise e if download_from == 'hf': - download_from_hf_hub(args.model, - save_dir=args.save_dir, - token=args.token, - tokenizer_only=args.tokenizer_only, - prefer_safetensors=args.prefer_safetensors) + download_from_hf_hub( + args.model, + save_dir=args.save_dir, + token=args.token, + tokenizer_only=args.tokenizer_only, + prefer_safetensors=args.prefer_safetensors, + ) elif download_from == 'oras': - download_from_oras(args.model, - args.config_file, - args.credentials_dir, - args.save_dir, - tokenizer_only=args.tokenizer_only, - concurrency=args.concurrency) + download_from_oras( + args.model, + args.config_file, + args.credentials_dir, + args.save_dir, + tokenizer_only=args.tokenizer_only, + concurrency=args.concurrency, + ) diff --git a/scripts/misc/profile_packing.py b/scripts/misc/profile_packing.py index fff10d158b..6bd048fd97 100644 --- a/scripts/misc/profile_packing.py +++ b/scripts/misc/profile_packing.py @@ -18,38 +18,46 @@ def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( description= - 'Profile packing_ratio choices for a particular workload.') + 'Profile packing_ratio choices for a particular workload.', + ) parser.add_argument( '--yaml-path', type=str, required=True, - help='Path to the YAML that defines the workload to profile.') - parser.add_argument('--num-devices', - type=int, - required=True, - help='How many devices your run will use.') - parser.add_argument('--min', - type=float, - required=True, - help='Smallest packing_ratio to test. Must be >=1.') + help='Path to the YAML that defines the workload to profile.', + ) + parser.add_argument( + '--num-devices', + type=int, + required=True, + help='How many devices your run will use.', + ) + parser.add_argument( + '--min', + type=float, + required=True, + help='Smallest packing_ratio to test. Must be >=1.', + ) parser.add_argument( '--max', type=float, required=True, - help='Largest packing_ratio to test. Must be larger than `min`.') + help='Largest packing_ratio to test. Must be larger than `min`.', + ) parser.add_argument( '--num-packing-ratios', type=int, default=20, help= - 'Number of packing_ratio values (spaced between `min` and `max) to try.' + 'Number of packing_ratio values (spaced between `min` and `max) to try.', ) args = parser.parse_args() if not os.path.isfile(args.yaml_path): raise FileNotFoundError( - '`yaml_path` does not correspond to any existing file.') + '`yaml_path` does not correspond to any existing file.', + ) if args.num_devices < 1: raise ValueError('`num_devices` must be a positive integer.') if args.min < 1.0: @@ -81,15 +89,22 @@ def parse_args() -> Namespace: resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True) if not isinstance(resolved_tokenizer_cfg, Dict): raise ValueError( - 'tokenizer config needs to be resolved by omegaconf into a Dict.') + 'tokenizer config needs to be resolved by omegaconf into a Dict.', + ) tokenizer_cfg = resolved_tokenizer_cfg tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max, - args.num_packing_ratios, device_batch_size) + results = profile_packing( + dataloader_cfg, + tokenizer, + args.min, + args.max, + args.num_packing_ratios, + device_batch_size, + ) header = '\n\n\n packing_ratio | % PADDING | % WASTE' fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%' diff --git a/scripts/misc/update_hub_code.py b/scripts/misc/update_hub_code.py index ee5f6935a3..20bb92fd04 100644 --- a/scripts/misc/update_hub_code.py +++ b/scripts/misc/update_hub_code.py @@ -56,14 +56,22 @@ def main(hf_repos_for_upload: List[str]): for repo in hf_repos_for_upload: print(f'Testing code changes for {repo}') pr_model = transformers.AutoModelForCausalLM.from_pretrained( - original_save_dir, trust_remote_code=True, device_map='auto') + original_save_dir, + trust_remote_code=True, + device_map='auto', + ) pr_tokenizer = transformers.AutoTokenizer.from_pretrained( - repo, trust_remote_code=True) + repo, + trust_remote_code=True, + ) - generation = pr_model.generate(pr_tokenizer( - 'MosaicML is', return_tensors='pt').input_ids.to( - 'cuda' if torch.cuda.is_available() else 'cpu'), - max_new_tokens=2) + generation = pr_model.generate( + pr_tokenizer( + 'MosaicML is', + return_tensors='pt', + ).input_ids.to('cuda' if torch.cuda.is_available() else 'cpu'), + max_new_tokens=2, + ) _ = pr_tokenizer.batch_decode(generation) print(f'Opening PR against {repo}') @@ -83,12 +91,14 @@ def main(hf_repos_for_upload: List[str]): if __name__ == '__main__': parser = argparse.ArgumentParser( description= - 'Update MPT code in HuggingFace Hub repos to be in sync with the local codebase' + 'Update MPT code in HuggingFace Hub repos to be in sync with the local codebase', + ) + parser.add_argument( + '--hf_repos_for_upload', + help='List of repos to open PRs against', + nargs='+', + required=True, ) - parser.add_argument('--hf_repos_for_upload', - help='List of repos to open PRs against', - nargs='+', - required=True) args = parser.parse_args() diff --git a/scripts/train/benchmarking/collect_results.py b/scripts/train/benchmarking/collect_results.py index 151286dbc6..26788788a2 100644 --- a/scripts/train/benchmarking/collect_results.py +++ b/scripts/train/benchmarking/collect_results.py @@ -24,24 +24,30 @@ def str_to_bool(value: Union[bool, str]): def parse_args(): - parser = argparse.ArgumentParser(description=""" + parser = argparse.ArgumentParser( + description=""" Parse run configs to get MPT training throughput. MFU and HFU are defined in https://arxiv.org/abs/2205.05198 All FLOP calculations do not include norm, act, residual, etc. - """) + """, + ) parser.add_argument('--project', type=str, default='tput') parser.add_argument('--filters', type=str, default=[], nargs='+') - parser.add_argument('-s', - '--save-path', - type=str, - default='benchmark_results') - parser.add_argument('-p', - '--print-results', - type=str_to_bool, - nargs='?', - const=True, - default=False) + parser.add_argument( + '-s', + '--save-path', + type=str, + default='benchmark_results', + ) + parser.add_argument( + '-p', + '--print-results', + type=str_to_bool, + nargs='?', + const=True, + default=False, + ) return parser.parse_args() @@ -68,9 +74,15 @@ def sort_key(r: msdk.Run): print(model_name) raise ValueError model_size = int(model_name[:-1]) - return (gpu_type, model_precision, model_name_size, model_size, - r.submitted_config.parameters['max_seq_len'], num_gpu, - r.submitted_config.parameters['global_train_batch_size']) + return ( + gpu_type, + model_precision, + model_name_size, + model_size, + r.submitted_config.parameters['max_seq_len'], + num_gpu, + r.submitted_config.parameters['global_train_batch_size'], + ) unique_runs = {sort_key(i): i for i in runs} runs = [unique_runs[r] for r in unique_runs] @@ -84,7 +96,7 @@ def filter_runs(runs: List[msdk.Run]): for run in runs: if run.status == msdk.RunStatus('FAILED'): print( - f"run {run.name} has FAILED (likely due to OOM error but we'd recommend checking.)" + f"run {run.name} has FAILED (likely due to OOM error but we'd recommend checking.)", ) pop_runs.append(run) @@ -161,8 +173,8 @@ def parse_run(run: msdk.Run) -> Dict[str, Any]: # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2) attn_flops_per_seq = n_layers * 2 * 2 * (d_model * (seq_len**2)) # there are 2 ops in bwd pass and 1 in fwd pass so we mult by 3 - mfu_w_attn = (3 * flops_per_seq + 3 * attn_flops_per_seq) * throughput / ( - gpus * GPU_AVAILABLE_FLOPS) + mfu_w_attn = (3 * flops_per_seq + 3 * attn_flops_per_seq + ) * throughput / (gpus * GPU_AVAILABLE_FLOPS) if activation_checkpointing: hfu_w_attn = (4 * flops_per_seq + 4 * attn_flops_per_seq @@ -170,8 +182,8 @@ def parse_run(run: msdk.Run) -> Dict[str, Any]: else: hfu_w_attn = mfu_w_attn - model_tflop = int( - (3 * flops_per_seq + 3 * attn_flops_per_seq) * throughput / gpus / 1e12) + model_tflop = int((3 * flops_per_seq + 3 * attn_flops_per_seq) * + throughput / gpus / 1e12,) return { 'Model': @@ -225,7 +237,7 @@ def main(args: argparse.Namespace): for run in runs: try: results.append(parse_run(run)) - except Exception as e: + except Exception as e: # noqa: PERF203 print(f'{run.name=} not parsed') print(e) diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index 5e83ae41b7..7e5bae7afc 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -40,69 +40,91 @@ def str_to_bool(value: Union[bool, str]): def parse_args(): parser = argparse.ArgumentParser( description= - 'Generate and run configurations to test MPT training throughput on the MosaicML platform.' + 'Generate and run configurations to test MPT training throughput on the MosaicML platform.', ) parser.add_argument('--project', type=str, default='tput') parser.add_argument( '--image', type=str, - default='mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04') - parser.add_argument('--git_branch', - type=str, - default=None, - help='what git branch to use.') - parser.add_argument('--git_commit', - type=str, - default=None, - help='what git commit to use.') - parser.add_argument('-t', - '--precisions', - '--types', - type=str, - default=['bf16'], - nargs='+', - choices=['bf16', 'fp16', 'fp8']) - parser.add_argument('--fsdp_config_mixed_precision', - type=str, - default='PURE') - parser.add_argument('--fsdp_config_activation_checkpointing', - type=str_to_bool, - nargs='?', - const=True, - default=None) - parser.add_argument('--fsdp_config_shard_strategy', - type=str, - nargs='?', - const=True, - default=None) - parser.add_argument('--fsdp_config_limit_all_gathers', - type=str_to_bool, - nargs='?', - const=True, - default=None) - parser.add_argument('--fsdp_config_forward_prefetch', - type=str_to_bool, - nargs='?', - const=True, - default=None) - parser.add_argument('--fsdp_config_backward_prefetch', - type=str, - nargs='?', - const=True, - default=None) - parser.add_argument('--activation_cpu_offload', - type=str_to_bool, - nargs='?', - const=True, - default=None) + default='mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04', + ) + parser.add_argument( + '--git_branch', + type=str, + default=None, + help='what git branch to use.', + ) + parser.add_argument( + '--git_commit', + type=str, + default=None, + help='what git commit to use.', + ) + parser.add_argument( + '-t', + '--precisions', + '--types', + type=str, + default=['bf16'], + nargs='+', + choices=['bf16', 'fp16', 'fp8'], + ) + parser.add_argument( + '--fsdp_config_mixed_precision', + type=str, + default='PURE', + ) + parser.add_argument( + '--fsdp_config_activation_checkpointing', + type=str_to_bool, + nargs='?', + const=True, + default=None, + ) + parser.add_argument( + '--fsdp_config_shard_strategy', + type=str, + nargs='?', + const=True, + default=None, + ) + parser.add_argument( + '--fsdp_config_limit_all_gathers', + type=str_to_bool, + nargs='?', + const=True, + default=None, + ) + parser.add_argument( + '--fsdp_config_forward_prefetch', + type=str_to_bool, + nargs='?', + const=True, + default=None, + ) + parser.add_argument( + '--fsdp_config_backward_prefetch', + type=str, + nargs='?', + const=True, + default=None, + ) + parser.add_argument( + '--activation_cpu_offload', + type=str_to_bool, + nargs='?', + const=True, + default=None, + ) parser.add_argument( '-s', '--seq_len_exp', type=int, default=[11, 11], nargs=2, - help='exponent of seq lengths to be tested (default: [11, 11] = 2048)') + help='exponent of seq lengths to be tested (default: [11, 11] = 2048)', + ) parser.add_argument( '-b', '--batch_size_exp', @@ -110,7 +132,7 @@ def parse_args(): default=None, nargs=2, help= - 'exponent of batch size (in tokens) to be tested (default: [19, 23] = 2^19 to 2^23)' + 'exponent of batch size (in tokens) to be tested (default: [19, 23] = 2^19 to 2^23)', ) parser.add_argument( '--batch_sizes', @@ -125,71 +147,99 @@ def parse_args(): default=None, help='batch sizes multiplier (accumulations before step).', ) - parser.add_argument('-m', - '--model_yamls', - type=str, - default=[ - '125m.yaml', '350m.yaml', '760m.yaml', '1b.yaml', - '3b.yaml', '7b.yaml', '13b.yaml', '30b.yaml', - '70b.yaml' - ], - choices=[ - '125m.yaml', '350m.yaml', '760m.yaml', '1b.yaml', - '3b.yaml', '7b.yaml', '13b.yaml', '30b.yaml', - '70b.yaml' - ], - nargs='+', - help='model sizes to test') + parser.add_argument( + '-m', + '--model_yamls', + type=str, + default=[ + '125m.yaml', + '350m.yaml', + '760m.yaml', + '1b.yaml', + '3b.yaml', + '7b.yaml', + '13b.yaml', + '30b.yaml', + '70b.yaml', + ], + choices=[ + '125m.yaml', + '350m.yaml', + '760m.yaml', + '1b.yaml', + '3b.yaml', + '7b.yaml', + '13b.yaml', + '30b.yaml', + '70b.yaml', + ], + nargs='+', + help='model sizes to test', + ) parser.add_argument('--attn_impl', type=str, default='flash') - parser.add_argument('-c', - '--clusters', - type=str, - default=['r1z1'], - nargs='+', - choices=CLUSTER_INFO.keys()) + parser.add_argument( + '-c', + '--clusters', + type=str, + default=['r1z1'], + nargs='+', + choices=CLUSTER_INFO.keys(), + ) known_args = parser.parse_known_args()[0] _gpu_types = get_gpu_types(known_args.clusters) - parser.add_argument('--gpu_types', - type=str, - default=['a100_40gb'], - nargs='+', - choices=_gpu_types) + parser.add_argument( + '--gpu_types', + type=str, + default=['a100_40gb'], + nargs='+', + choices=_gpu_types, + ) known_args = parser.parse_known_args()[0] _gpu_nums = get_gpu_nums(known_args.clusters, known_args.gpu_types) - parser.add_argument('-g', - '--gpu_nums', - type=int, - default=[8], - nargs='+', - choices=_gpu_nums) - - parser.add_argument('--microbatch_size', - type=int, - default=None, - help='set microbatch_size') + parser.add_argument( + '-g', + '--gpu_nums', + type=int, + default=[8], + nargs='+', + choices=_gpu_nums, + ) + + parser.add_argument( + '--microbatch_size', + type=int, + default=None, + help='set microbatch_size', + ) parser.add_argument('--pad_vocab_multiple', type=int, default=None) - parser.add_argument('--data_remote', - type=str, - default=None, - help='optional data remote path for streaming data') + parser.add_argument( + '--data_remote', + type=str, + default=None, + help='optional data remote path for streaming data', + ) - parser.add_argument('--wandb', - type=str_to_bool, - nargs='?', - const=True, - default=True) + parser.add_argument( + '--wandb', + type=str_to_bool, + nargs='?', + const=True, + default=True, + ) parser.add_argument('--priority', type=str, default='lowest') - parser.add_argument('--RUN', - type=str_to_bool, - nargs='?', - const=True, - default=False) + parser.add_argument( + '--RUN', + type=str_to_bool, + nargs='?', + const=True, + default=False, + ) return parser.parse_args() @@ -199,16 +249,19 @@ def get_max_seq_lens(pows: Optional[List[int]] = None): return [2**n for n in range(pows[0], pows[1] + 1)] -def get_global_train_batch_sizes(max_seq_len: int, - pows: List[int], - batch_sizes: Optional[List[int]] = None): +def get_global_train_batch_sizes( + max_seq_len: int, + pows: List[int], + batch_sizes: Optional[List[int]] = None, +): if batch_sizes is None: batch_sizes = [] if pows: # global batch size in tokens (default: .5M thru 8M) global_train_token_counts = [2**n for n in range(pows[0], pows[1] + 1)] - batch_sizes += [t // max_seq_len for t in global_train_token_counts - ] # global batch size in samples + batch_sizes += [ + t // max_seq_len for t in global_train_token_counts + ] # global batch size in samples return batch_sizes @@ -290,8 +343,8 @@ def mod_parameters( 'data_remote'] parameters['data_local'] = '/tmp/c4' - parameters['train_loader']['dataset']['local'] = parameters[ - 'data_local'] + parameters['train_loader']['dataset']['local'] = parameters['data_local' + ] parameters['eval_loader']['dataset']['local'] = parameters['data_local'] else: parameters['train_loader']['dataset'][ @@ -310,7 +363,8 @@ def mod_parameters( if pad_vocab_multiple: vocab_size = parameters['model']['vocab_size'] parameters['model']['vocab_size'] = math.ceil( - vocab_size / pad_vocab_multiple) * pad_vocab_multiple + vocab_size / pad_vocab_multiple, + ) * pad_vocab_multiple parameters['tokenizer']['kwargs']['model_max_length'] = max_seq_len parameters['train_loader']['dataset']['max_seq_len'] = max_seq_len @@ -323,10 +377,11 @@ def mod_parameters( # update eval batch size based on change in seq len parameters['device_eval_batch_size'] = max( 1, - int(parameters['device_eval_batch_size'] / ((max_seq_len / 2048)**2))) + int(parameters['device_eval_batch_size'] / ((max_seq_len / 2048)**2)), + ) - parameters['eval_loader'][ - 'eval_subset_num_batches'] = 2 # for throughput testing purposes + parameters['eval_loader']['eval_subset_num_batches' + ] = 2 # for throughput testing purposes parameters['max_duration'] = max_duration parameters['eval_interval'] = eval_interval @@ -334,23 +389,23 @@ def mod_parameters( parameters['precision'] = precision parameters['fsdp_config']['mixed_precision'] = fsdp_config_mixed_precision if fsdp_config_activation_checkpointing is not None: - parameters['fsdp_config'][ - 'activation_checkpointing'] = fsdp_config_activation_checkpointing + parameters['fsdp_config']['activation_checkpointing' + ] = fsdp_config_activation_checkpointing if fsdp_config_shard_strategy is not None: - parameters['fsdp_config'][ - 'sharding_strategy'] = fsdp_config_shard_strategy + parameters['fsdp_config']['sharding_strategy' + ] = fsdp_config_shard_strategy if fsdp_config_limit_all_gathers is not None: - parameters['fsdp_config'][ - 'limit_all_gathers'] = fsdp_config_limit_all_gathers + parameters['fsdp_config']['limit_all_gathers' + ] = fsdp_config_limit_all_gathers if fsdp_config_forward_prefetch is not None: - parameters['fsdp_config'][ - 'forward_prefetch'] = fsdp_config_forward_prefetch + parameters['fsdp_config']['forward_prefetch' + ] = fsdp_config_forward_prefetch if fsdp_config_backward_prefetch is not None: - parameters['fsdp_config'][ - 'backward_prefetch'] = fsdp_config_backward_prefetch + parameters['fsdp_config']['backward_prefetch' + ] = fsdp_config_backward_prefetch if activation_cpu_offload is not None: - parameters['fsdp_config'][ - 'activation_cpu_offload'] = activation_cpu_offload + parameters['fsdp_config']['activation_cpu_offload' + ] = activation_cpu_offload if wandb: # add wandb @@ -359,10 +414,12 @@ def mod_parameters( return parameters -def get_integrations(project: str, - git_branch: Optional[str] = None, - git_commit: Optional[str] = None, - wandb: bool = True): +def get_integrations( + project: str, + git_branch: Optional[str] = None, + git_commit: Optional[str] = None, + wandb: bool = True, +): integrations = [] if git_branch and git_commit: @@ -376,7 +433,7 @@ def get_integrations(project: str, git_integration.update({ 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', - 'pip_install': '.[gpu-flash2]' + 'pip_install': '.[gpu-flash2]', }) integrations = [git_integration] @@ -385,14 +442,16 @@ def get_integrations(project: str, integrations += [{ 'integration_type': 'wandb', 'entity': 'mosaic-ml', - 'project': project + 'project': project, }] return integrations -def run_config(config: Tuple[str, int, int, str, str, int, str], - args: argparse.Namespace): +def run_config( + config: Tuple[str, int, int, str, str, int, str], + args: argparse.Namespace, +): model_yaml, max_seq_len, global_train_batch_size, cluster, gpu_type, gpu_num, precision = config integrations = [ { @@ -404,7 +463,7 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], { 'integration_type': 'wandb', 'entity': 'mosaic-ml', - 'project': args.project + 'project': args.project, }, ] @@ -432,14 +491,17 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], path = os.path.join('../yamls/pretrain', 'mpt-' + model_yaml) parameters = get_parameters(path) - model_name = '-'.join(model_yaml.split('.')[-2].split('/')[-2:]).replace( - '_', '-') + model_name = '-'.join( + model_yaml.split('.')[-2].split('/')[-2:], + ).replace('_', '-') model_name = model_name.split('-') if 'mosaic' in model_name: model_name.pop(model_name.index('mosaic')) model_name = ''.join(model_name) name = f"{args.project}-{cluster}-{model_name}-{gpu_num}x{gpu_type}-s{max_seq_len}b{global_train_batch_size}{precision.replace('amp_', '')}".replace( - '_', '-') + '_', + '-', + ) name_len_lim = 54 - 7 if len(name) > name_len_lim: @@ -469,18 +531,19 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], if gpu_type == 'h100_80gb' and precision == 'fp8': parameters['model']['fc_type'] = 'te' # Create run config mcli sdk/api - config = RunConfig(name=name, - compute={ - 'cluster': cluster, - 'gpu_type': gpu_type, - 'gpus': gpu_num - }, - image=args.image, - integrations=integrations, - command=command, - parameters=parameters, - scheduling=SchedulingConfig(priority=args.priority, - resumable=True)) + config = RunConfig( + name=name, + compute={ + 'cluster': cluster, + 'gpu_type': gpu_type, + 'gpus': gpu_num, + }, + image=args.image, + integrations=integrations, + command=command, + parameters=parameters, + scheduling=SchedulingConfig(priority=args.priority, resumable=True), + ) if args.RUN: # Create the run from a config run = create_run(config) @@ -490,10 +553,12 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], print(f'{config=}') -def run_check_capacity(model_yaml: str, - gpu_num: int, - gpu_type: str, - p_multiplier: int = 16): +def run_check_capacity( + model_yaml: str, + gpu_num: int, + gpu_type: str, + p_multiplier: int = 16, +): _params = model_yaml.replace('.yaml', '') params, mult = int(_params[:-1]), _params[-1] if mult == 'm': @@ -507,7 +572,7 @@ def run_check_capacity(model_yaml: str, if p_multiplier * b_params > gpu_num * gpu_mem: print( - f'WARNING: will not be running {model_yaml=} on {gpu_num=} {gpu_type=} since it probably will not fit into memory' + f'WARNING: will not be running {model_yaml=} on {gpu_num=} {gpu_type=} since it probably will not fit into memory', ) return False return True @@ -516,7 +581,7 @@ def run_check_capacity(model_yaml: str, def run_check_dtms(num_gpus: int, dtms: int, batch_size: int): if num_gpus * dtms > batch_size: print( - f'WARNING: Cannot run with {batch_size=} on {num_gpus=} with {dtms=} ({num_gpus*dtms=}).' + f'WARNING: Cannot run with {batch_size=} on {num_gpus=} with {dtms=} ({num_gpus*dtms=}).', ) return False return True @@ -532,33 +597,44 @@ def run_check_dtms(num_gpus: int, dtms: int, batch_size: int): _gpu_nums = [ng for ng in args.gpu_nums if ng <= ng_lim] for gpu_num in _gpu_nums: global_train_batch_sizes = get_global_train_batch_sizes( - max_seq_len, args.batch_size_exp, args.batch_sizes) + max_seq_len, + args.batch_size_exp, + args.batch_sizes, + ) if not global_train_batch_sizes and args.microbatch_size is not None: accum = args.accum or 1 global_train_batch_sizes = [ - accum * gpu_num * args.microbatch_size + accum * gpu_num * args.microbatch_size, ] for global_train_batch_size in global_train_batch_sizes: for precision in args.precisions: for model_yaml in args.model_yamls: - run = run_check_capacity(model_yaml, - gpu_num, - gpu_type, - p_multiplier=4) + run = run_check_capacity( + model_yaml, + gpu_num, + gpu_type, + p_multiplier=4, + ) if args.microbatch_size is not None: run = run and run_check_dtms( - gpu_num, args.microbatch_size, - global_train_batch_size) + gpu_num, + args.microbatch_size, + global_train_batch_size, + ) if run: config: Tuple[str, int, int, str, str, int, str] = ( - model_yaml, max_seq_len, + model_yaml, + max_seq_len, global_train_batch_size, - cluster, gpu_type, - gpu_num, precision) + cluster, + gpu_type, + gpu_num, + precision, + ) run_config(config, args) n_jobs += 1 diff --git a/scripts/train/benchmarking/sweep.py b/scripts/train/benchmarking/sweep.py index 57b3aa262c..c113127464 100644 --- a/scripts/train/benchmarking/sweep.py +++ b/scripts/train/benchmarking/sweep.py @@ -85,7 +85,8 @@ # Iterate over the arguments and call submit_benchmarks.py for num_gpu_args in num_gpu_args_list: for model_args in model_args_list: - command = ['python submit_benchmarks.py' - ] + base_args + num_gpu_args + model_args + command = [ + 'python submit_benchmarks.py', + ] + base_args + num_gpu_args + model_args command = ' '.join(command) os.system(command) diff --git a/scripts/train/finetune_example/preprocessing.py b/scripts/train/finetune_example/preprocessing.py index adfa3c5cce..5f0639d22b 100644 --- a/scripts/train/finetune_example/preprocessing.py +++ b/scripts/train/finetune_example/preprocessing.py @@ -35,7 +35,8 @@ def multiple_choice( - inp: Dict[str, Union[str, List[str], int]]) -> Dict[str, str]: + inp: Dict[str, Union[str, List[str], int]], +) -> Dict[str, str]: PROMPT_FORMAT = '{query}\nOptions:{options}\nAnswer: ' options = '' assert isinstance(inp['choices'], List) diff --git a/scripts/train/train.py b/scripts/train/train.py index 65ac11f63c..63426755dc 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -12,29 +12,45 @@ import torch from composer import Trainer from composer.core.callback import Callback -from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler, - cyclic_schedule) +from composer.profiler import ( + JSONTraceHandler, + Profiler, + TraceHandler, + cyclic_schedule, +) from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from rich.traceback import install from llmfoundry.eval.metrics.nlp import InContextLearningMetric -from llmfoundry.utils import (find_mosaicml_logger, log_train_analytics, - maybe_create_mosaicml_logger) +from llmfoundry.utils import ( + find_mosaicml_logger, + log_train_analytics, + maybe_create_mosaicml_logger, +) install() from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader from llmfoundry.layers_registry import ffns_with_megablocks -from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, - build_algorithm, build_callback, - build_composer_model, build_evaluators, - build_logger, build_optimizer, - build_scheduler, build_tokenizer) -from llmfoundry.utils.config_utils import (log_config, pop_config, - process_init_device, - update_batch_size_info) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_algorithm, + build_callback, + build_composer_model, + build_evaluators, + build_logger, + build_optimizer, + build_scheduler, + build_tokenizer, +) +from llmfoundry.utils.config_utils import ( + log_config, + pop_config, + process_init_device, + update_batch_size_info, +) from llmfoundry.utils.registry_utils import import_file log = logging.getLogger(__name__) @@ -50,7 +66,8 @@ def validate_config(cfg: DictConfig): if loader.label is None: raise ValueError( 'When specifying multiple evaluation datasets, each one must include the \ - `label` attribute.') + `label` attribute.', + ) loaders.append(loader) else: loaders.append(eval_loader) @@ -64,53 +81,59 @@ def validate_config(cfg: DictConfig): if 'icl_tasks' in cfg: if cfg.model.name == 'hf_t5': raise ValueError( - 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".' + 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".', ) - if (cfg.model.get('fc_type', 'torch') != 'te' and 'te' not in cfg.model.get( - 'ffn_config', {}).get('ffn_type', 'mptmlp') and - 'fp8' in cfg.precision): + if ( + cfg.model.get('fc_type', 'torch') != 'te' and 'te' + not in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') and + 'fp8' in cfg.precision + ): warnings.warn( "fp8 only supported for te.Linear layers. Either set `cfg.model.fc_typ='te'` or " + - "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision." + "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision.", ) - if (cfg.model.get('fc_type', 'torch') == 'te' or - 'te' in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp')): + if ( + cfg.model.get('fc_type', 'torch') == 'te' or + 'te' in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') + ): fsdp_config = cfg.get('fsdp_config', None) act_ckpt = fsdp_config.get('activation_checkpointing', False) act_ckpt_reentrant = fsdp_config.get( - 'activation_checkpointing_reentrant', False) + 'activation_checkpointing_reentrant', + False, + ) if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: warnings.warn( '`te.Linear` layers do not support activation_checkpointing with ' + '`activation_checkpointing_reentrant = True`. ' + - 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.' + 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.', ) cfg.fsdp_config.activation_checkpointing_reentrant = False if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': warnings.warn( '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + - 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.' + 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.', ) torch._dynamo.config.suppress_errors = True # type: ignore (third-party) if cfg.model.get('load_in_8bit', False): raise ValueError( - '`load_in_8bit` is only supported for evaluation rather than training.' + '`load_in_8bit` is only supported for evaluation rather than training.', ) - if cfg.model.get('ffn_config', {}).get('ffn_type', - 'mptmlp') in ffns_with_megablocks: + if cfg.model.get('ffn_config', + {}).get('ffn_type', 'mptmlp') in ffns_with_megablocks: moe_world_size = cfg.model.get('ffn_config', {}).get('moe_world_size', 1) use_orig_params = cfg.get('fsdp_config', {}).get('use_orig_params', True) if moe_world_size > 1 and not use_orig_params: raise ValueError( - f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.' + f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.', ) @@ -129,11 +152,13 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): def main(cfg: DictConfig) -> Trainer: # Run user provided code if specified - code_paths = pop_config(cfg, - 'code_paths', - must_exist=False, - default_value=[], - convert=True) + code_paths = pop_config( + cfg, + 'code_paths', + must_exist=False, + default_value=[], + convert=True, + ) # Import any user provided code for code_path in code_paths: import_file(code_path) @@ -143,7 +168,7 @@ def main(cfg: DictConfig) -> Trainer: action='ignore', category=UserWarning, message= - 'torch.distributed.*_base is a private function and will be deprecated.*' + 'torch.distributed.*_base is a private function and will be deprecated.*', ) # Check for incompatibilities between the model and data loaders @@ -179,26 +204,32 @@ def main(cfg: DictConfig) -> Trainer: reproducibility.seed_all(seed) # Initialize pytorch distributed training process groups - dist_timeout: Union[int, float] = pop_config(cfg, - 'dist_timeout', - must_exist=False, - default_value=600.0) - python_log_level: Optional[str] = pop_config(cfg, - 'python_log_level', - must_exist=False, - default_value='debug') + dist_timeout: Union[int, float] = pop_config( + cfg, + 'dist_timeout', + must_exist=False, + default_value=600.0, + ) + python_log_level: Optional[str] = pop_config( + cfg, + 'python_log_level', + must_exist=False, + default_value='debug', + ) # Set logging level if python_log_level is not None: logging.basicConfig( # Example of format string # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', ) logging.getLogger('llmfoundry').setLevel( - python_log_level.upper()) # Foundry module + python_log_level.upper(), + ) # Foundry module logging.getLogger(__name__).setLevel( - python_log_level.upper()) # Train script + python_log_level.upper(), + ) # Train script _initialize_dist_with_barrier(dist_timeout=dist_timeout) @@ -208,176 +239,227 @@ def main(cfg: DictConfig) -> Trainer: # Mandatory model training configs model_config: DictConfig = pop_config(cfg, 'model', must_exist=True) - tokenizer_config: Dict[str, Any] = pop_config(cfg, - 'tokenizer', - must_exist=True, - convert=True) - optimizer_config: Dict[str, Any] = pop_config(cfg, - 'optimizer', - must_exist=True, - convert=True) - scheduler_config: Dict[str, Any] = pop_config(cfg, - 'scheduler', - must_exist=True, - convert=True) - train_loader_config: DictConfig = pop_config(cfg, - 'train_loader', - must_exist=True) + tokenizer_config: Dict[ + str, Any] = pop_config(cfg, 'tokenizer', must_exist=True, convert=True) + optimizer_config: Dict[ + str, Any] = pop_config(cfg, 'optimizer', must_exist=True, convert=True) + scheduler_config: Dict[ + str, Any] = pop_config(cfg, 'scheduler', must_exist=True, convert=True) + train_loader_config: DictConfig = pop_config( + cfg, + 'train_loader', + must_exist=True, + ) # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any]] = pop_config(cfg, - 'fsdp_config', - must_exist=False, - default_value=None, - convert=True) - eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config( - cfg, 'eval_loader', must_exist=False, default_value=None) - icl_tasks_config: Optional[Union[ListConfig, - str]] = pop_config(cfg, - 'icl_tasks', - must_exist=False, - default_value=None) - eval_gauntlet_config: Optional[Union[DictConfig, - str]] = pop_config(cfg, - 'eval_gauntlet', - must_exist=False, - default_value=None) - icl_subset_num_batches: Optional[int] = pop_config(cfg, - 'icl_subset_num_batches', - must_exist=False, - default_value=None) - icl_seq_len: Optional[int] = pop_config(cfg, - 'icl_seq_len', - must_exist=False, - default_value=None) + fsdp_config: Optional[Dict[str, Any]] = pop_config( + cfg, + 'fsdp_config', + must_exist=False, + default_value=None, + convert=True, + ) + eval_loader_config: Optional[ + Union[DictConfig, ListConfig] + ] = pop_config(cfg, 'eval_loader', must_exist=False, default_value=None) + icl_tasks_config: Optional[ + Union[ListConfig, str] + ] = pop_config(cfg, 'icl_tasks', must_exist=False, default_value=None) + eval_gauntlet_config: Optional[ + Union[DictConfig, str] + ] = pop_config(cfg, 'eval_gauntlet', must_exist=False, default_value=None) + icl_subset_num_batches: Optional[int] = pop_config( + cfg, + 'icl_subset_num_batches', + must_exist=False, + default_value=None, + ) + icl_seq_len: Optional[int] = pop_config( + cfg, + 'icl_seq_len', + must_exist=False, + default_value=None, + ) # Optional logging, evaluation and callback configs - logger_configs: Optional[DictConfig] = pop_config(cfg, - 'loggers', - must_exist=False, - default_value=None, - convert=True) - callback_configs: Optional[DictConfig] = pop_config(cfg, - 'callbacks', - must_exist=False, - default_value=None, - convert=True) - algorithm_configs: Optional[DictConfig] = pop_config(cfg, - 'algorithms', - must_exist=False, - default_value=None) + logger_configs: Optional[DictConfig] = pop_config( + cfg, + 'loggers', + must_exist=False, + default_value=None, + convert=True, + ) + callback_configs: Optional[DictConfig] = pop_config( + cfg, + 'callbacks', + must_exist=False, + default_value=None, + convert=True, + ) + algorithm_configs: Optional[DictConfig] = pop_config( + cfg, + 'algorithms', + must_exist=False, + default_value=None, + ) # Mandatory hyperparameters for training - device_train_batch_size: int = pop_config(cfg, - 'device_train_batch_size', - must_exist=True) - device_eval_batch_size: int = pop_config(cfg, - 'device_eval_batch_size', - must_exist=True) - max_duration: Union[int, str] = pop_config(cfg, - 'max_duration', - must_exist=True) - eval_interval: Union[int, str] = pop_config(cfg, - 'eval_interval', - default_value=1, - must_exist=False) + device_train_batch_size: int = pop_config( + cfg, + 'device_train_batch_size', + must_exist=True, + ) + device_eval_batch_size: int = pop_config( + cfg, + 'device_eval_batch_size', + must_exist=True, + ) + max_duration: Union[int, + str] = pop_config(cfg, 'max_duration', must_exist=True) + eval_interval: Union[int, str] = pop_config( + cfg, + 'eval_interval', + default_value=1, + must_exist=False, + ) precision: str = pop_config(cfg, 'precision', must_exist=True) max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True) # Optional parameters will be set to default values if not specified. default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = pop_config(cfg, - 'run_name', - must_exist=False, - default_value=default_run_name) - save_folder: Optional[str] = pop_config(cfg, - 'save_folder', - must_exist=False, - default_value=None) - is_state_dict_sharded: bool = (fsdp_config.get('state_dict_type', 'full') - == 'sharded') if fsdp_config else False + run_name: str = pop_config( + cfg, + 'run_name', + must_exist=False, + default_value=default_run_name, + ) + save_folder: Optional[str] = pop_config( + cfg, + 'save_folder', + must_exist=False, + default_value=None, + ) + is_state_dict_sharded: bool = ( + fsdp_config.get('state_dict_type', 'full') == 'sharded' + ) if fsdp_config else False save_latest_filename: str = pop_config( cfg, 'save_latest_filename', must_exist=False, default_value='latest-sharded-rank{rank}' - if is_state_dict_sharded else 'latest-rank{rank}.pt') - save_overwrite: bool = pop_config(cfg, - 'save_overwrite', - must_exist=False, - default_value=False) - save_weights_only: bool = pop_config(cfg, - 'save_weights_only', - must_exist=False, - default_value=False) + if is_state_dict_sharded else 'latest-rank{rank}.pt', + ) + save_overwrite: bool = pop_config( + cfg, + 'save_overwrite', + must_exist=False, + default_value=False, + ) + save_weights_only: bool = pop_config( + cfg, + 'save_weights_only', + must_exist=False, + default_value=False, + ) save_filename: str = pop_config( cfg, 'save_filename', must_exist=False, - default_value='ep{epoch}-ba{batch}-rank{rank}.pt') - save_interval: Union[str, int] = pop_config(cfg, - 'save_interval', - must_exist=False, - default_value='1000ba') + default_value='ep{epoch}-ba{batch}-rank{rank}.pt', + ) + save_interval: Union[str, int] = pop_config( + cfg, + 'save_interval', + must_exist=False, + default_value='1000ba', + ) save_num_checkpoints_to_keep: int = pop_config( - cfg, 'save_num_checkpoints_to_keep', must_exist=False, default_value=-1) - progress_bar = pop_config(cfg, - 'progress_bar', - must_exist=False, - default_value=False) - log_to_console: bool = pop_config(cfg, - 'log_to_console', - must_exist=False, - default_value=True) - console_log_interval: Union[int, str] = pop_config(cfg, - 'console_log_interval', - must_exist=False, - default_value='1ba') + cfg, + 'save_num_checkpoints_to_keep', + must_exist=False, + default_value=-1, + ) + progress_bar = pop_config( + cfg, + 'progress_bar', + must_exist=False, + default_value=False, + ) + log_to_console: bool = pop_config( + cfg, + 'log_to_console', + must_exist=False, + default_value=True, + ) + console_log_interval: Union[int, str] = pop_config( + cfg, + 'console_log_interval', + must_exist=False, + default_value='1ba', + ) device_train_microbatch_size: Union[str, int] = pop_config( cfg, 'device_train_microbatch_size', must_exist=False, - default_value='auto') - eval_subset_num_batches: int = pop_config(cfg, - 'eval_subset_num_batches', - must_exist=False, - default_value=-1) - eval_first: bool = pop_config(cfg, - 'eval_first', - must_exist=False, - default_value=False) - load_path: str = pop_config(cfg, - 'load_path', - must_exist=False, - default_value=None) - load_weights_only: bool = pop_config(cfg, - 'load_weights_only', - must_exist=False, - default_value=False) - load_strict_model_weights: bool = pop_config(cfg, - 'load_strict_model_weights', - must_exist=False, - default_value=True) - load_ignore_keys: Optional[List[str]] = pop_config(cfg, - 'load_ignore_keys', - must_exist=False, - default_value=None) - save_ignore_keys: Optional[List[str]] = pop_config(cfg, - 'save_ignore_keys', - must_exist=False, - default_value=None) - compile_config: Optional[Dict[str, Any]] = pop_config(cfg, - 'compile_config', - must_exist=False, - default_value=None) - metadata: Optional[Dict[str, str]] = pop_config(cfg, - 'metadata', - must_exist=False, - default_value=None, - convert=True) - should_log_config: bool = pop_config(cfg, - 'log_config', - must_exist=False, - default_value=True) + default_value='auto', + ) + eval_subset_num_batches: int = pop_config( + cfg, + 'eval_subset_num_batches', + must_exist=False, + default_value=-1, + ) + eval_first: bool = pop_config( + cfg, + 'eval_first', + must_exist=False, + default_value=False, + ) + load_path: str = pop_config( + cfg, + 'load_path', + must_exist=False, + default_value=None, + ) + load_weights_only: bool = pop_config( + cfg, + 'load_weights_only', + must_exist=False, + default_value=False, + ) + load_strict_model_weights: bool = pop_config( + cfg, + 'load_strict_model_weights', + must_exist=False, + default_value=True, + ) + load_ignore_keys: Optional[List[str]] = pop_config( + cfg, + 'load_ignore_keys', + must_exist=False, + default_value=None, + ) + save_ignore_keys: Optional[List[str]] = pop_config( + cfg, + 'save_ignore_keys', + must_exist=False, + default_value=None, + ) + compile_config: Optional[ + Dict[str, Any] + ] = pop_config(cfg, 'compile_config', must_exist=False, default_value=None) + metadata: Optional[Dict[str, str]] = pop_config( + cfg, + 'metadata', + must_exist=False, + default_value=None, + convert=True, + ) + should_log_config: bool = pop_config( + cfg, + 'log_config', + must_exist=False, + default_value=True, + ) # Enable autoresume from model checkpoints if possible autoresume_default: bool = False @@ -388,13 +470,17 @@ def main(cfg: DictConfig) -> Trainer: autoresume_default = True if cfg.get('autoresume') is None and autoresume_default: - log.info('As run_name, save_folder, and save_latest_filename are set, \ - changing autoresume default to True...') + log.info( + 'As run_name, save_folder, and save_latest_filename are set, \ + changing autoresume default to True...', + ) - autoresume: bool = pop_config(cfg, - 'autoresume', - must_exist=False, - default_value=autoresume_default) + autoresume: bool = pop_config( + cfg, + 'autoresume', + must_exist=False, + default_value=autoresume_default, + ) # Pop known unused parameters that are used as interpolation variables or # created by update_batch_size_info. @@ -408,13 +494,14 @@ def main(cfg: DictConfig) -> Trainer: # Warn users for unused parameters for key in cfg: warnings.warn( - f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.' + f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.', ) # Warn if fsdp is enabled but user only has 1 GPU if dist.get_world_size() == 1 and fsdp_config is not None: warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.') + 'FSDP is not applicable for single-GPU training. Reverting to DDP.', + ) fsdp_config = None # Initialize context @@ -454,37 +541,47 @@ def main(cfg: DictConfig) -> Trainer: # Profiling profiler: Optional[Profiler] = None - profiler_cfg: Optional[DictConfig] = pop_config(cfg, - 'profiler', - must_exist=False, - convert=False, - default_value=None) + profiler_cfg: Optional[DictConfig] = pop_config( + cfg, + 'profiler', + must_exist=False, + convert=False, + default_value=None, + ) if profiler_cfg: - profiler_schedule_cfg: Dict = pop_config(profiler_cfg, - 'schedule', - must_exist=True, - convert=True) + profiler_schedule_cfg: Dict = pop_config( + profiler_cfg, + 'schedule', + must_exist=True, + convert=True, + ) profiler_schedule = cyclic_schedule(**profiler_schedule_cfg) # Only support json trace handler profiler_trace_handlers: List[TraceHandler] = [] - profiler_trace_cfg: Optional[Dict] = pop_config(profiler_cfg, - 'json_trace_handler', - must_exist=False, - default_value=None, - convert=True) + profiler_trace_cfg: Optional[Dict] = pop_config( + profiler_cfg, + 'json_trace_handler', + must_exist=False, + default_value=None, + convert=True, + ) if profiler_trace_cfg: profiler_trace_handlers.append( - JSONTraceHandler(**profiler_trace_cfg)) - profiler = Profiler(**profiler_cfg, - trace_handlers=profiler_trace_handlers, - schedule=profiler_schedule) + JSONTraceHandler(**profiler_trace_cfg), + ) + profiler = Profiler( + **profiler_cfg, + trace_handlers=profiler_trace_handlers, + schedule=profiler_schedule, + ) # Callbacks callbacks: List[Callback] = [ - build_callback(name=str(name), - kwargs=callback_cfg, - train_config=om.to_container(logged_cfg)) - for name, callback_cfg in callback_configs.items() + build_callback( + name=str(name), + kwargs=callback_cfg, + train_config=om.to_container(logged_cfg), + ) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) @@ -516,7 +613,7 @@ def main(cfg: DictConfig) -> Trainer: evaluators = [] if eval_first: warnings.warn( - 'AsyncEval callback does not support eval_first=True. Ignoring.' + 'AsyncEval callback does not support eval_first=True. Ignoring.', ) eval_first = False @@ -536,10 +633,17 @@ def main(cfg: DictConfig) -> Trainer: callbacks.append(eval_gauntlet_callback) if mosaicml_logger is not None: - log_train_analytics(mosaicml_logger, model_config, train_loader_config, - eval_loader_config, callback_configs, - tokenizer_name, load_path, icl_tasks_config, - eval_gauntlet_config) + log_train_analytics( + mosaicml_logger, + model_config, + train_loader_config, + eval_loader_config, + callback_configs, + tokenizer_name, + load_path, + icl_tasks_config, + eval_gauntlet_config, + ) # Build Model log.info('Initializing model...') model = build_composer_model( @@ -557,7 +661,8 @@ def main(cfg: DictConfig) -> Trainer: else: n_params = sum(p.numel() for p in model.parameters()) n_trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad) + p.numel() for p in model.parameters() if p.requires_grad + ) if hasattr(model, 'n_active_params'): n_active_params = model.n_active_params else: @@ -580,8 +685,10 @@ def main(cfg: DictConfig) -> Trainer: metric_name for metric_name, metric in eval_metrics.items() if not isinstance(metric, InContextLearningMetric) ] - evaluators = add_metrics_to_eval_loaders(evaluators, - non_icl_metrics) + evaluators = add_metrics_to_eval_loaders( + evaluators, + non_icl_metrics, + ) except Exception as e: if mosaicml_logger is not None: mosaicml_logger.log_exception(e) diff --git a/setup.py b/setup.py index 3954ec698e..d454f6040e 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,8 @@ # we put parens around the version so that it becomes elem 1 of the match expr = re.compile( r"""^__version__\s*=\s*['"]([0-9]+\.[0-9]+\.[0-9]+(?:\.\w+)?)['"]""", - re.MULTILINE) + re.MULTILINE, +) repo_version = expr.findall(content)[0] # Use repo README for PyPi description @@ -122,14 +123,18 @@ 'grouped-gemm==0.1.4', ] -extra_deps['all-cpu'] = set(dep for key, deps in extra_deps.items() - for dep in deps - if 'gpu' not in key and 'megablocks' not in key) -extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps - if key not in {'gpu-flash2', 'all-cpu'}) -extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items() - for dep in deps - if key not in {'gpu', 'all', 'all-cpu'}) +extra_deps['all-cpu'] = { + dep for key, deps in extra_deps.items() for dep in deps + if 'gpu' not in key and 'megablocks' not in key +} +extra_deps['all'] = { + dep for key, deps in extra_deps.items() for dep in deps + if key not in {'gpu-flash2', 'all-cpu'} +} +extra_deps['all-flash2'] = { + dep for key, deps in extra_deps.items() for dep in deps + if key not in {'gpu', 'all', 'all-cpu'} +} setup( name=_PACKAGE_NAME, @@ -144,7 +149,8 @@ 'llmfoundry': ['py.typed'], }, packages=setuptools.find_packages( - exclude=['.github*', 'mcli*', 'scripts*', 'tests*']), + exclude=['.github*', 'mcli*', 'scripts*', 'tests*'], + ), classifiers=classifiers, install_requires=install_requires, extras_require=extra_deps, diff --git a/tests/a_scripts/data_prep/test_convert_dataset_hf.py b/tests/a_scripts/data_prep/test_convert_dataset_hf.py index f226b0a4be..4c5d1a6bba 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_hf.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_hf.py @@ -23,6 +23,8 @@ def test_download_script_from_api(tmp_path: Path): 'bos_text': None, 'eos_text': None, 'no_wrap': False, - 'num_workers': None - })) + 'num_workers': None, + }, + ), + ) assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_dataset_json.py b/tests/a_scripts/data_prep/test_convert_dataset_json.py index 179b8a701b..912e44cd0c 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_json.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_json.py @@ -22,6 +22,8 @@ def test_json_script_from_api(tmp_path: Path): 'bos_text': None, 'eos_text': None, 'no_wrap': False, - 'num_workers': None - })) + 'num_workers': None, + }, + ), + ) assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 7839455563..e4619b8a56 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -9,9 +9,12 @@ from typing import Any from unittest.mock import MagicMock, mock_open, patch -from scripts.data_prep.convert_delta_to_json import (download, fetch_DT, - iterative_combine_jsons, - run_query) +from scripts.data_prep.convert_delta_to_json import ( + download, + fetch_DT, + iterative_combine_jsons, + run_query, +) class TestConvertDeltaToJsonl(unittest.TestCase): @@ -21,9 +24,14 @@ class TestConvertDeltaToJsonl(unittest.TestCase): @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - def test_stream_delta_to_json(self, mock_workspace_client: Any, - mock_fetch: Any, mock_combine_jsons: Any, - mock_makedirs: Any, mock_sql_connect: Any): + def test_stream_delta_to_json( + self, + mock_workspace_client: Any, + mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_sql_connect: Any, + ): args = MagicMock() args.delta_table_name = 'test_table' @@ -40,22 +48,29 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any, mock_cluster_get = MagicMock() mock_cluster_get.return_value = MagicMock( - spark_version='14.1.0-scala2.12') + spark_version='14.1.0-scala2.12', + ) mock_workspace_client.return_value.clusters.get = mock_cluster_get fetch_DT(args) - mock_sql_connect.assert_called_once_with(server_hostname='test_host', - http_path='test_path', - access_token='test_token') + mock_sql_connect.assert_called_once_with( + server_hostname='test_host', + http_path='test_path', + access_token='test_token', + ) mock_makedirs.assert_called_once_with('/path/to/jsonl', exist_ok=True) mock_fetch.assert_called_once() mock_combine_jsons.assert_called_once_with( - '/path/to/jsonl', '/path/to/jsonl/combined.jsonl') + '/path/to/jsonl', + '/path/to/jsonl/combined.jsonl', + ) @patch('scripts.data_prep.convert_delta_to_json.os.listdir') - @patch('builtins.open', - new_callable=mock_open, - read_data='{"key": "value"}') + @patch( + 'builtins.open', + new_callable=mock_open, + read_data='{"key": "value"}', + ) def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): mock_listdir.return_value = ['file1.jsonl', 'file2.jsonl'] json_directory = '/fake/dir' @@ -92,10 +107,12 @@ def test_run_query_dbconnect(self, mock_spark: Any): mock_cursor = None mock_spark.sql.return_value.collect.return_value = 'result' - result = run_query('SELECT * FROM table', - method, - cursor=mock_cursor, - spark=mock_spark) + result = run_query( + 'SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark, + ) mock_spark.sql.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') @@ -106,38 +123,53 @@ def test_run_query_dbsql(self, mock_cursor: Any): mock_cursor.fetchall.return_value = 'result' mock_spark = None - result = run_query('SELECT * FROM table', - method, - cursor=mock_cursor, - spark=mock_spark) + result = run_query( + 'SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark, + ) mock_cursor.execute.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') @patch('scripts.data_prep.convert_delta_to_json.requests.get') @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') - @patch('scripts.data_prep.convert_delta_to_json.os.path.join', - return_value='/fake/path/part_1.jsonl') - @patch('scripts.data_prep.convert_delta_to_json.time.sleep' - ) # Mock sleep to speed up the test - def test_download_success(self, mock_sleep: Any, mock_join: Any, - mock_to_json: Any, mock_get: Any): + @patch( + 'scripts.data_prep.convert_delta_to_json.os.path.join', + return_value='/fake/path/part_1.jsonl', + ) + @patch( + 'scripts.data_prep.convert_delta_to_json.time.sleep', + ) # Mock sleep to speed up the test + def test_download_success( + self, + mock_sleep: Any, + mock_join: Any, + mock_to_json: Any, + mock_get: Any, + ): mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = [['val1.1', 'val1.2'], ['val2.1', 'val2.2']] mock_get.return_value = mock_response - download(1, - 'http://fakeurl.com/data', - '/fake/path', ['A', 'B'], - resp_format='json') + download( + 1, + 'http://fakeurl.com/data', + '/fake/path', + ['A', 'B'], + resp_format='json', + ) mock_get.assert_called_with('http://fakeurl.com/data') mock_join.assert_called_with('/fake/path', 'part_1.jsonl') - mock_to_json.assert_called_with('/fake/path/part_1.jsonl', - orient='records', - lines=True) + mock_to_json.assert_called_with( + '/fake/path/part_1.jsonl', + orient='records', + lines=True, + ) mock_get.assert_called_once_with('http://fakeurl.com/data') @@ -147,10 +179,15 @@ def test_download_success(self, mock_sleep: Any, mock_join: Any, @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, - mock_makedirs: Any, mock_workspace_client: Any, - mock_databricks_session: Any, - mock_sql_connect: Any): + def test_dbconnect_called( + self, + mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any, + ): args = MagicMock() @@ -175,7 +212,8 @@ def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, mock_databricks_session.builder.remote.assert_called_once_with( host=args.DATABRICKS_HOST, token=args.DATABRICKS_TOKEN, - cluster_id=args.cluster_id) + cluster_id=args.cluster_id, + ) @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @@ -183,12 +221,15 @@ def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_sqlconnect_called_dbr13(self, mock_fetch: Any, - mock_combine_jsons: Any, - mock_makedirs: Any, - mock_workspace_client: Any, - mock_databricks_session: Any, - mock_sql_connect: Any): + def test_sqlconnect_called_dbr13( + self, + mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any, + ): args = MagicMock() @@ -208,7 +249,8 @@ def test_sqlconnect_called_dbr13(self, mock_fetch: Any, mock_sql_connect.assert_called_once_with( server_hostname=args.DATABRICKS_HOST, http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN) + access_token=args.DATABRICKS_TOKEN, + ) @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @@ -216,12 +258,15 @@ def test_sqlconnect_called_dbr13(self, mock_fetch: Any, @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_sqlconnect_called_dbr14(self, mock_fetch: Any, - mock_combine_jsons: Any, - mock_makedirs: Any, - mock_workspace_client: Any, - mock_databricks_session: Any, - mock_sql_connect: Any): + def test_sqlconnect_called_dbr14( + self, + mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any, + ): args = MagicMock() @@ -241,7 +286,8 @@ def test_sqlconnect_called_dbr14(self, mock_fetch: Any, mock_sql_connect.assert_called_once_with( server_hostname=args.DATABRICKS_HOST, http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN) + access_token=args.DATABRICKS_TOKEN, + ) @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @@ -249,12 +295,15 @@ def test_sqlconnect_called_dbr14(self, mock_fetch: Any, @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_sqlconnect_called_https(self, mock_fetch: Any, - mock_combine_jsons: Any, - mock_makedirs: Any, - mock_workspace_client: Any, - mock_databricks_session: Any, - mock_sql_connect: Any): + def test_sqlconnect_called_https( + self, + mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any, + ): args = MagicMock() @@ -274,7 +323,8 @@ def test_sqlconnect_called_https(self, mock_fetch: Any, mock_sql_connect.assert_called_once_with( server_hostname='test-host', http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN) + access_token=args.DATABRICKS_TOKEN, + ) @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @@ -282,9 +332,15 @@ def test_sqlconnect_called_https(self, mock_fetch: Any, @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any, - mock_makedirs: Any, mock_workspace_client: Any, - mock_databricks_session: Any, mock_sql_connect: Any): + def test_serverless( + self, + mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any, + ): args = MagicMock() diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index bd96de695c..df4309e13d 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -14,20 +14,28 @@ from streaming import StreamingDataset from transformers import AutoTokenizer -from llmfoundry.utils.exceptions import (InputFolderMissingDataError, - OutputFolderNotEmptyError) -from scripts.data_prep.convert_text_to_mds import (DONE_FILENAME, - convert_text_to_mds, - download_and_convert, - is_already_processed, - merge_shard_groups, - write_done_file) +from llmfoundry.utils.exceptions import ( + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) +from scripts.data_prep.convert_text_to_mds import ( + DONE_FILENAME, + convert_text_to_mds, + download_and_convert, + is_already_processed, + merge_shard_groups, + write_done_file, +) class MockObjectStore(): - def __init__(self, remote_folder: str, n_text_files: int, - text_content: str): + def __init__( + self, + remote_folder: str, + n_text_files: int, + text_content: str, + ): os.makedirs(remote_folder, exist_ok=True) for i in range(n_text_files): with open(os.path.join(remote_folder, f'test{i}.txt'), 'w') as f: @@ -36,16 +44,19 @@ def __init__(self, remote_folder: str, n_text_files: int, self.remote_folder = remote_folder self.n_text_files = n_text_files - def download_object(self, - object_name: str, - filename: str, - overwrite: bool = False): + def download_object( + self, + object_name: str, + filename: str, + overwrite: bool = False, + ): dirname = os.path.dirname(filename) if dirname: os.makedirs(dirname, exist_ok=True) with open( - os.path.join(self.remote_folder, os.path.basename(object_name)), - 'rb') as remote_file, open(filename, 'wb') as local_file: + os.path.join(self.remote_folder, os.path.basename(object_name)), + 'rb', + ) as remote_file, open(filename, 'wb') as local_file: local_file.write(remote_file.read()) def list_objects(self, prefix: str) -> List[str]: @@ -53,8 +64,9 @@ def list_objects(self, prefix: str) -> List[str]: def upload_object(self, object_name: str, filename: str): with open( - os.path.join(self.remote_folder, os.path.basename(object_name)), - 'wb') as remote_file, open(filename, 'rb') as local_file: + os.path.join(self.remote_folder, os.path.basename(object_name)), + 'wb', + ) as remote_file, open(filename, 'rb') as local_file: remote_file.write(local_file.read()) @@ -72,16 +84,25 @@ def _assert_files_exist(prefix: str, files: List[str]): @pytest.mark.parametrize('processes', [1, 2, 3]) @patch.object(ProcessPoolExecutor, 'map', new=Mock(wraps=_mock_map)) @patch( - 'scripts.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri') + 'scripts.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri', +) @patch('scripts.data_prep.convert_text_to_mds.parse_uri') -@patch('scripts.data_prep.convert_text_to_mds.download_and_convert', - wraps=download_and_convert) -@patch('scripts.data_prep.convert_text_to_mds.merge_shard_groups', - wraps=merge_shard_groups) -def test_single_and_multi_process(merge_shard_groups: Mock, - download_and_convert: Mock, parse_uri: Mock, - maybe_create_object_store_from_uri: Mock, - tmp_path: pathlib.Path, processes: int): +@patch( + 'scripts.data_prep.convert_text_to_mds.download_and_convert', + wraps=download_and_convert, +) +@patch( + 'scripts.data_prep.convert_text_to_mds.merge_shard_groups', + wraps=merge_shard_groups, +) +def test_single_and_multi_process( + merge_shard_groups: Mock, + download_and_convert: Mock, + parse_uri: Mock, + maybe_create_object_store_from_uri: Mock, + tmp_path: pathlib.Path, + processes: int, +): remote_folder = os.path.join(tmp_path, 'remote') text_content = 'HELLO WORLD ' * 500 tokenizer_name = 'mosaicml/mpt-7b' @@ -89,7 +110,8 @@ def test_single_and_multi_process(merge_shard_groups: Mock, concat_tokens = 2048 mock_object_store = Mock( - wraps=MockObjectStore(remote_folder, n_text_files, text_content)) + wraps=MockObjectStore(remote_folder, n_text_files, text_content), + ) maybe_create_object_store_from_uri.return_value = mock_object_store parse_uri.return_value = ('s3', 'fake-test-bucket', str(remote_folder)) @@ -128,8 +150,10 @@ def call_convert_text_to_mds() -> None: # Check that correct output files exist shards = [f'shard.0000{i}.mds.zstd' for i in range(processes)] - _assert_files_exist(prefix=remote_folder, - files=['index.json', DONE_FILENAME] + shards) + _assert_files_exist( + prefix=remote_folder, + files=['index.json', DONE_FILENAME] + shards, + ) call_convert_text_to_mds() @@ -154,10 +178,12 @@ def call_convert_text_to_mds() -> None: # Compute the expected number of tokens tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) tokens_per_file = len(tokenizer(text_content)['input_ids']) - files_per_process = [n_text_files // processes - ] * processes # Distrubte the files equally + files_per_process = [ + n_text_files // processes, + ] * processes # Distrubte the files equally files_per_process[ - 0] += n_text_files % processes # Give one of the processes the remainder + 0 + ] += n_text_files % processes # Give one of the processes the remainder # expected number of tokens accounts for last tokens dropped by ConcatTokensDataset expected_n_tokens = sum([ ((n_files * tokens_per_file) // concat_tokens) * concat_tokens @@ -222,8 +248,10 @@ def call_convert_text_to_mds(reprocess: bool): def test_input_folder_not_exist(tmp_path: pathlib.Path): - with pytest.raises(InputFolderMissingDataError, - match='No text files were found'): + with pytest.raises( + InputFolderMissingDataError, + match='No text files were found', + ): convert_text_to_mds( tokenizer_name='mosaicml/mpt-7b', output_folder=str(tmp_path / 'output'), @@ -245,17 +273,29 @@ def test_is_already_processed(tmp_path: pathlib.Path): args_str = 'Namespace(x = 5)' object_names = ['test0.txt', 'test1.txt'] - assert not is_already_processed(tmp_path_str, args_str, - object_names) # Done file doesn't exist + assert not is_already_processed( + tmp_path_str, + args_str, + object_names, + ) # Done file doesn't exist write_done_file(tmp_path_str, args_str, object_names) - assert is_already_processed(tmp_path_str, args_str, - object_names) # Args and names match + assert is_already_processed( + tmp_path_str, + args_str, + object_names, + ) # Args and names match write_done_file(tmp_path_str, args_str, object_names + ['test2.txt']) - assert not is_already_processed(tmp_path_str, args_str, - object_names) # Object names differ + assert not is_already_processed( + tmp_path_str, + args_str, + object_names, + ) # Object names differ write_done_file(tmp_path_str, 'Namespace()', object_names) - assert not is_already_processed(tmp_path_str, args_str, - object_names) # Argument strings differ + assert not is_already_processed( + tmp_path_str, + args_str, + object_names, + ) # Argument strings differ diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index 63c4ea8261..2c7f81a8b2 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -14,8 +14,11 @@ from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from scripts.eval.eval import main # noqa: E402 -from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, - gpt_tiny_cfg) +from tests.data_utils import ( + create_arxiv_dataset, + create_c4_dataset_xxsmall, + gpt_tiny_cfg, +) @pytest.fixture(autouse=True) @@ -44,8 +47,10 @@ def mock_saved_model_path(eval_cfg: Union[om.ListConfig, om.DictConfig]): device = 'cpu' model_cfg.model.init_device = device # build tokenizer - tokenizer = build_tokenizer(model_cfg.tokenizer.name, - model_cfg.tokenizer.get('kwargs', {})) + tokenizer = build_tokenizer( + model_cfg.tokenizer.name, + model_cfg.tokenizer.get('kwargs', {}), + ) # build model model = build_composer_model( name=model_cfg.model.name, @@ -63,8 +68,11 @@ def mock_saved_model_path(eval_cfg: Union[om.ListConfig, om.DictConfig]): os.remove(saved_model_path) -def test_icl_eval(eval_cfg: Union[om.ListConfig, om.DictConfig], capfd: Any, - mock_saved_model_path: Any): +def test_icl_eval( + eval_cfg: Union[om.ListConfig, om.DictConfig], + capfd: Any, + mock_saved_model_path: Any, +): eval_cfg.models[0].load_path = mock_saved_model_path assert isinstance(eval_cfg, om.DictConfig) main(eval_cfg) @@ -75,8 +83,11 @@ def test_icl_eval(eval_cfg: Union[om.ListConfig, om.DictConfig], capfd: Any, assert expected_results in out -def test_loader_eval(capfd: Any, mock_saved_model_path: Any, - tmp_path: pathlib.Path): +def test_loader_eval( + capfd: Any, + mock_saved_model_path: Any, + tmp_path: pathlib.Path, +): c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) @@ -92,8 +103,8 @@ def test_loader_eval(capfd: Any, mock_saved_model_path: Any, 'eval/local_data/language_understanding/lambada_openai_small.jsonl', 'num_fewshot': [0], 'icl_task_type': - 'language_modeling' - }) + 'language_modeling', + }), ]) # convert the model from a training to eval model @@ -101,7 +112,7 @@ def test_loader_eval(capfd: Any, mock_saved_model_path: Any, eval_model = { 'model_name': model.get('name'), 'model': model, - 'load_path': mock_saved_model_path + 'load_path': mock_saved_model_path, } tokenizer = test_cfg.pop('tokenizer') @@ -116,8 +127,10 @@ def test_loader_eval(capfd: Any, mock_saved_model_path: Any, arxiv_dataset_name = create_arxiv_dataset(tmp_path) second_eval_loader.data_local = arxiv_dataset_name second_eval_loader.label = 'arxiv' - test_cfg.eval_loader = om.OmegaConf.create( - [first_eval_loader, second_eval_loader]) + test_cfg.eval_loader = om.OmegaConf.create([ + first_eval_loader, + second_eval_loader, + ]) test_cfg.max_duration = '1ba' test_cfg.eval_interval = '1ba' @@ -140,19 +153,28 @@ def test_loader_eval(capfd: Any, mock_saved_model_path: Any, # Checks for first eval dataloader assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list) + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + list, + ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0 + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple) + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + tuple, + ) # Checks for second eval dataloader assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys( ) assert isinstance( - inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list) + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], + list, + ) assert len( - inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0 + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1], + ) > 0 assert isinstance( inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1], - tuple) + tuple, + ) diff --git a/tests/a_scripts/eval/test_eval_inputs.py b/tests/a_scripts/eval/test_eval_inputs.py index 8694546c4f..030fc434bf 100644 --- a/tests/a_scripts/eval/test_eval_inputs.py +++ b/tests/a_scripts/eval/test_eval_inputs.py @@ -37,14 +37,18 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None: ] mandatory_configs = ['models', 'icl_tasks'] for p in mandatory_params + mandatory_configs: - with pytest.raises((omegaconf.errors.ConfigKeyError, - omegaconf.errors.InterpolationKeyError)): + with pytest.raises(( + omegaconf.errors.ConfigKeyError, + omegaconf.errors.InterpolationKeyError, + )): cfg[p + '-mispelled'] = cfg.pop(p) main(cfg) cfg[p] = cfg.pop(p + '-mispelled') - def test_optional_mispelled_params_raise_warning(self, - cfg: DictConfig) -> None: + def test_optional_mispelled_params_raise_warning( + self, + cfg: DictConfig, + ) -> None: """Check that warnings are raised for optional mispelled parameters.""" optional_params = [ 'seed', @@ -66,8 +70,10 @@ def test_optional_mispelled_params_raise_warning(self, main(cfg) except: pass - assert any(f'Unused parameter {updated_param} found in cfg.' in - str(warning.message) for warning in warning_list) + assert any( + f'Unused parameter {updated_param} found in cfg.' in + str(warning.message) for warning in warning_list + ) # restore configs. cfg = copy.deepcopy(old_cfg) @@ -85,7 +91,10 @@ def cfg(self, foundry_dir: str) -> DictConfig: test_cfg = om.load(config) test_cfg.icl_tasks[0].dataset_uri = os.path.join( - foundry_dir, 'scripts', test_cfg.icl_tasks[0].dataset_uri) + foundry_dir, + 'scripts', + test_cfg.icl_tasks[0].dataset_uri, + ) # make tests use cpu initialized transformer models only test_cfg.models[0].model.init_device = 'cpu' diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index ee01fb743d..f21666fdbe 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -26,8 +26,11 @@ from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename from llmfoundry.data.finetuning import build_finetuning_dataloader from llmfoundry.models.mpt import MPTConfig -from llmfoundry.utils.builders import (build_composer_model, build_optimizer, - build_tokenizer) +from llmfoundry.utils.builders import ( + build_composer_model, + build_optimizer, + build_tokenizer, +) from llmfoundry.utils.config_utils import process_init_device from scripts.inference.convert_composer_to_hf import convert_composer_to_hf from tests.data_utils import make_tiny_ft_dataset @@ -45,8 +48,10 @@ def _save_model_mock(*args: Any, path: str, **kwargs: Any): os.makedirs(path, exist_ok=True) -def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, - tokenizer2: PreTrainedTokenizerBase): +def check_hf_tokenizer_equivalence( + tokenizer1: PreTrainedTokenizerBase, + tokenizer2: PreTrainedTokenizerBase, +): """WARNING: Parameters are updated within the check so don't call check_hf_tokenizer_equivalence on the same params more than once @@ -62,12 +67,15 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, # we only care about the file and class name, not the full import path assert str(type(tokenizer1)).split('.')[-2:] == str( - type(tokenizer2)).split('.')[-2:] + type(tokenizer2), + ).split('.')[-2:] expected_tokenizer_output = tokenizer2( - 'This is some text that should get tokenizer !? @ totallyarealtoken') + 'This is some text that should get tokenizer !? @ totallyarealtoken', + ) actual_tokenizer_output = tokenizer1( - 'This is some text that should get tokenizer !? @ totallyarealtoken') + 'This is some text that should get tokenizer !? @ totallyarealtoken', + ) assert expected_tokenizer_output == actual_tokenizer_output # we remove the actual _tokenizer object because it is an instantiated object and so does not pass equality @@ -86,8 +94,10 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, tokenizer2.__dict__.pop('tokens_trie') # extra key that is not important - if hasattr(tokenizer1, 'deprecation_warnings') or hasattr( - tokenizer2, 'deprecation_warnings'): + if hasattr( + tokenizer1, + 'deprecation_warnings', + ) or hasattr(tokenizer2, 'deprecation_warnings'): tokenizer1.__dict__.pop('deprecation_warnings') tokenizer2.__dict__.pop('deprecation_warnings') @@ -140,9 +150,13 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, # The tokenizer name is changed in transformers 4.31 when changing the tokenizer mapping, so we remove it and compare # if necessary. Checks whether the names are subsets of each other. tokenizer1_name = tokenizer1.__dict__['init_kwargs'].get( - 'auto_map', {}).get('AutoTokenizer', [None])[0] + 'auto_map', + {}, + ).get('AutoTokenizer', [None])[0] tokenizer2_name = tokenizer2.__dict__['init_kwargs'].get( - 'auto_map', {}).get('AutoTokenizer', [None])[0] + 'auto_map', + {}, + ).get('AutoTokenizer', [None])[0] if tokenizer1_name is not None and tokenizer2_name is not None: assert tokenizer1_name in tokenizer2_name or tokenizer2_name in tokenizer1_name tokenizer1.__dict__['init_kwargs'].pop('auto_map', None) @@ -165,8 +179,8 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, tokenizer2.__dict__['init_kwargs'].pop('added_tokens_decoder', None) # If the additional special tokens are the same (or a subset of each other), or if one of them is empty, then we are good assert additional_special_tokens_1.issubset( - additional_special_tokens_2) or additional_special_tokens_2.issubset( - additional_special_tokens_1) + additional_special_tokens_2, + ) or additional_special_tokens_2.issubset(additional_special_tokens_1) # The special token attributes may be strings or they may be AddedToken objects, so we just check string values # First check that they have the same attrs @@ -200,9 +214,11 @@ def remove_moe_world_size(config: MPTConfig): config.ffn_config.pop('moe_world_size') -def check_hf_model_equivalence(model1: PreTrainedModel, - model2: PreTrainedModel, - just_lora: bool = False): +def check_hf_model_equivalence( + model1: PreTrainedModel, + model2: PreTrainedModel, + just_lora: bool = False, +): remove_moe_world_size(model1.config) remove_moe_world_size(model2.config) @@ -227,12 +243,14 @@ def check_hf_model_equivalence(model1: PreTrainedModel, assert auto_map_1 == {'AutoConfig': 'configuration_mpt.MPTConfig'} assert auto_map_2 == { 'AutoConfig': 'configuration_mpt.MPTConfig', - 'AutoModelForCausalLM': 'modeling_mpt.MPTForCausalLM' + 'AutoModelForCausalLM': 'modeling_mpt.MPTForCausalLM', } assert expected_model_config_dict == new_model_config_dict - for (n1, p1), (_, p2) in zip(model1.named_parameters(), - model2.named_parameters()): + for (n1, p1), ( + _, + p2, + ) in zip(model1.named_parameters(), model2.named_parameters()): if not just_lora or 'lora' in n1: assert torch.equal(p1.cpu(), p2.cpu()) @@ -246,16 +264,22 @@ def delete_transformers_cache(): hf_cache_home = os.path.expanduser( os.getenv( 'HF_HOME', - os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), - 'huggingface'))) - HF_MODULES_CACHE = os.getenv('HF_MODULES_CACHE', - os.path.join(hf_cache_home, 'modules')) + os.path.join( + os.getenv('XDG_CACHE_HOME', '~/.cache'), + 'huggingface', + ), + ), + ) + HF_MODULES_CACHE = os.getenv( + 'HF_MODULES_CACHE', + os.path.join(hf_cache_home, 'modules'), + ) if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE): shutil.rmtree(HF_MODULES_CACHE) def get_config( - conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml' + conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', ) -> DictConfig: os.environ['TOKENIZERS_PARALLELISM'] = 'false' with open(conf_path) as f: @@ -272,7 +296,8 @@ def test_callback_inits(): hf_checkpointer = HuggingFaceCheckpointer( save_folder='test', save_interval='1ba', - mlflow_registered_model_name='test_model_name') + mlflow_registered_model_name='test_model_name', + ) assert hf_checkpointer.mlflow_logging_config['task'] == 'llm/v1/completions' @@ -301,15 +326,25 @@ def is_alive(self) -> bool: @pytest.mark.parametrize('log_to_mlflow', [True, False]) @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', - [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)]) + [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)], +) @patch('os.cpu_count', MagicMock(return_value=1)) -@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', - new=MockSpawnProcess) +@patch( + 'llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess, +) def test_huggingface_conversion_callback_interval( - tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str, - save_interval: str, max_duration: str, expected_hf_checkpoints: int, - expected_normal_checkpoints: int, tiny_ft_dataloader: DataLoader, - mpt_tokenizer: PreTrainedTokenizerBase, build_tiny_mpt: Callable): + tmp_path: pathlib.Path, + log_to_mlflow: bool, + hf_save_interval: str, + save_interval: str, + max_duration: str, + expected_hf_checkpoints: int, + expected_normal_checkpoints: int, + tiny_ft_dataloader: DataLoader, + mpt_tokenizer: PreTrainedTokenizerBase, + build_tiny_mpt: Callable, +): delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -332,8 +367,11 @@ def test_huggingface_conversion_callback_interval( optimizer_config = _OPTIMIZER_CFG() optimizer_name = optimizer_config.pop('name') - optimizer = build_optimizer(original_model, optimizer_name, - optimizer_config) + optimizer = build_optimizer( + original_model, + optimizer_name, + optimizer_config, + ) mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} @@ -375,17 +413,20 @@ def test_huggingface_conversion_callback_interval( name for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) if name != 'huggingface' ] - huggingface_checkpoints = [ - name for name in os.listdir( - os.path.join(tmp_path, 'checkpoints', 'huggingface')) - ] + huggingface_checkpoints = list( + os.listdir(os.path.join(tmp_path, 'checkpoints', 'huggingface')), + ) assert len(normal_checkpoints) == expected_normal_checkpoints assert len(huggingface_checkpoints) == expected_hf_checkpoints # Load the last huggingface checkpoint loaded_model = transformers.AutoModelForCausalLM.from_pretrained( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{batches_per_epoch}'), + os.path.join( + tmp_path, + 'checkpoints', + 'huggingface', + f'ba{batches_per_epoch}', + ), trust_remote_code=True, ) @@ -403,20 +444,29 @@ def test_huggingface_conversion_callback_interval( loaded_model.config.init_device = original_model.model.config.init_device loaded_tokenizer = transformers.AutoTokenizer.from_pretrained( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{batches_per_epoch}'), + os.path.join( + tmp_path, + 'checkpoints', + 'huggingface', + f'ba{batches_per_epoch}', + ), trust_remote_code=True, ) - check_hf_model_equivalence(trainer.state.model.model.to(precision), - loaded_model) + check_hf_model_equivalence( + trainer.state.model.model.to(precision), + loaded_model, + ) check_hf_tokenizer_equivalence(mpt_tokenizer, loaded_tokenizer) delete_transformers_cache() -def _get_model_and_tokenizer(model: str, max_seq_len: int, - tie_word_embeddings: bool): +def _get_model_and_tokenizer( + model: str, + max_seq_len: int, + tie_word_embeddings: bool, +): if model == 'mpt': model_cfg = { 'name': 'mpt_causal_lm', @@ -480,7 +530,7 @@ def _get_model_and_tokenizer(model: str, max_seq_len: int, assert tie_word_embeddings is None if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: pytest.skip( - 'The CI cluster does not have access to the Llama models, so skip this test.' + 'The CI cluster does not have access to the Llama models, so skip this test.', ) model_cfg = { 'name': 'hf_causal_lm', @@ -500,8 +550,10 @@ def _get_model_and_tokenizer(model: str, max_seq_len: int, return model_cfg, tokenizer_name -def _assert_mlflow_logger_calls(mlflow_logger_mock: MagicMock, - peft_config: Optional[dict] = None): +def _assert_mlflow_logger_calls( + mlflow_logger_mock: MagicMock, + peft_config: Optional[dict] = None, +): if dist.get_global_rank() == 0: assert mlflow_logger_mock.save_model.call_count == 1 if peft_config is not None: @@ -515,7 +567,7 @@ def _assert_mlflow_logger_calls(mlflow_logger_mock: MagicMock, import numpy as np default_input_example = { - 'prompt': np.array(['What is Machine Learning?']) + 'prompt': np.array(['What is Machine Learning?']), } expectation = { @@ -524,7 +576,7 @@ def _assert_mlflow_logger_calls(mlflow_logger_mock: MagicMock, 'path': ANY, 'task': 'llm/v1/completions', 'input_example': default_input_example, - 'metadata': {} + 'metadata': {}, } mlflow_logger_mock.save_model.assert_called_with(**expectation) assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 @@ -563,22 +615,24 @@ def _get_dataloader_cfg(tiny_dataset_folder_path: str, max_seq_len: int): 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } return dataloader_cfg -def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, - expected_normal_checkpoints: int, - expected_hf_checkpoints: int, - trainer: Trainer, - batches_per_epoch: int, - precision: torch.dtype, - model: str, - tokenizer: PreTrainedTokenizerBase, - original_model: ComposerModel, - fsdp_state_dict_type: Optional[str] = None, - peft_config: Optional[dict] = None): +def _assert_checkpoint_equivalence( + tmp_path: pathlib.Path, + expected_normal_checkpoints: int, + expected_hf_checkpoints: int, + trainer: Trainer, + batches_per_epoch: int, + precision: torch.dtype, + model: str, + tokenizer: PreTrainedTokenizerBase, + original_model: ComposerModel, + fsdp_state_dict_type: Optional[str] = None, + peft_config: Optional[dict] = None, +): """Asserts the equivalence of checkpoints. Asserts equivalence of checkpoints between the original mpt model and the converted hf model. @@ -605,14 +659,18 @@ def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, name for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) if name != 'huggingface' ] - huggingface_checkpoints = [ - name for name in os.listdir( - os.path.join(tmp_path, 'checkpoints', 'huggingface')) - ] + huggingface_checkpoints = list( + os.listdir(os.path.join(tmp_path, 'checkpoints', 'huggingface')), + ) checkpoint_files = os.listdir( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - huggingface_checkpoints[-1])) + os.path.join( + tmp_path, + 'checkpoints', + 'huggingface', + huggingface_checkpoints[-1], + ), + ) if peft_config is not None: assert 'adapter_config.json' in checkpoint_files assert 'adapter_model.safetensors' in checkpoint_files @@ -625,23 +683,31 @@ def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, with patch.dict('sys.modules', {'flash_attn': None}): if peft_config is not None: composer_model = trainer.state.model.module if trainer.state.is_model_ddp else trainer.state.model - composer_model.model.base_model.save_pretrained(tmp_path / - 'base-model') + composer_model.model.base_model.save_pretrained( + tmp_path / 'base-model', + ) - checkpoint_path = os.path.join(tmp_path, 'checkpoints', - 'huggingface', - f'ba{batches_per_epoch}') + checkpoint_path = os.path.join( + tmp_path, + 'checkpoints', + 'huggingface', + f'ba{batches_per_epoch}', + ) if peft_config is not None: - with open(os.path.join(checkpoint_path, - 'adapter_config.json')) as _f: + with open( + os.path.join(checkpoint_path, 'adapter_config.json'), + ) as _f: adapter_config = json.load(_f) - adapter_config['base_model_name_or_path'] = str(tmp_path / - 'base-model') + adapter_config['base_model_name_or_path'] = str( + tmp_path / 'base-model', + ) - with open(os.path.join(checkpoint_path, 'adapter_config.json'), - 'w') as _f: + with open( + os.path.join(checkpoint_path, 'adapter_config.json'), + 'w', + ) as _f: json.dump(adapter_config, _f) # Load the last huggingface checkpoint @@ -667,8 +733,12 @@ def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, loaded_model.config.init_device = original_model.model.config.init_device loaded_tokenizer = transformers.AutoTokenizer.from_pretrained( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{batches_per_epoch}'), + os.path.join( + tmp_path, + 'checkpoints', + 'huggingface', + f'ba{batches_per_epoch}', + ), trust_remote_code=True, ) @@ -676,7 +746,8 @@ def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, trainer.state.model.model.to(precision) if fsdp_state_dict_type is not None else trainer.state.model.module.model.to(precision), loaded_model, - just_lora=peft_config is not None) + just_lora=peft_config is not None, + ) check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer) @@ -690,27 +761,34 @@ def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, ('mptmoe', None, None), ('neo', None, None), ('llama2', None, None), - ('llama2', None, { - 'peft_type': 'LORA', - 'task_type': 'CAUSAL_LM', - 'lora_alpha': 32, - 'lora_dropout': 0.05, - 'r': 16, - 'target_modules': [ - 'q_proj', - 'k_proj', - 'v_proj', - ], - }), + ( + 'llama2', + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'lora_dropout': 0.05, + 'r': 16, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }, + ), ], ) @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', - [('1ba', '1ba', '1ba', 1, 1)]) + [('1ba', '1ba', '1ba', 1, 1)], +) @patch('os.cpu_count', MagicMock(return_value=1)) -@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', - new=MockSpawnProcess) +@patch( + 'llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess, +) def test_huggingface_conversion_callback( model: str, tmp_path: pathlib.Path, @@ -740,11 +818,15 @@ def test_huggingface_conversion_callback( save_folder=os.path.join(tmp_path, 'checkpoints'), save_interval=hf_save_interval, precision=precision_str, - mlflow_registered_model_name='dummy-registered-name') + mlflow_registered_model_name='dummy-registered-name', + ) # Get small version of each model model_cfg, tokenizer_name = _get_model_and_tokenizer( - model, max_seq_len, tie_word_embeddings) + model, + max_seq_len, + tie_word_embeddings, + ) assert model_cfg is not None assert tokenizer_name is not None model_cfg = om.create(model_cfg) @@ -774,11 +856,17 @@ def test_huggingface_conversion_callback( device_batch_size, ) - original_model = build_composer_model(model_cfg['name'], model_cfg, - tokenizer) + original_model = build_composer_model( + model_cfg['name'], + model_cfg, + tokenizer, + ) optimizer_name = optimizer_config.pop('name') - optimizer = build_optimizer(original_model, optimizer_name, - optimizer_config) + optimizer = build_optimizer( + original_model, + optimizer_name, + optimizer_config, + ) mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} @@ -807,9 +895,11 @@ def test_huggingface_conversion_callback( # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - with FSDP.summon_full_params(trainer.state.model, - writeback=False, - recurse=True): + with FSDP.summon_full_params( + trainer.state.model, + writeback=False, + recurse=True, + ): _assert_checkpoint_equivalence( tmp_path=tmp_path, expected_normal_checkpoints=expected_normal_checkpoints, @@ -821,7 +911,8 @@ def test_huggingface_conversion_callback( model=model, tokenizer=tokenizer, fsdp_state_dict_type=fsdp_state_dict_type, - peft_config=peft_config) + peft_config=peft_config, + ) dist.barrier() delete_transformers_cache() @@ -834,33 +925,40 @@ def test_huggingface_conversion_callback( pytest.param('mptmoe', None, marks=pytest.mark.gpu), ('neo', None), ('llama2', None)], ) -def test_convert_and_generate(model: str, tie_word_embeddings: bool, - tmp_path: pathlib.Path): +def test_convert_and_generate( + model: str, + tie_word_embeddings: bool, + tmp_path: pathlib.Path, +): delete_transformers_cache() om_cfg = None if model == 'mpt': om_cfg = get_config( - conf_path='scripts/train/yamls/pretrain/testing.yaml') + conf_path='scripts/train/yamls/pretrain/testing.yaml', + ) om_cfg['tie_word_embeddings'] = tie_word_embeddings elif model == 'mptmoe': om_cfg = get_config( - conf_path='scripts/train/yamls/pretrain/testing-moe.yaml') + conf_path='scripts/train/yamls/pretrain/testing-moe.yaml', + ) elif model == 'neo': assert tie_word_embeddings is None om_cfg = get_config( - conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml') + conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml', + ) om_cfg['model']['config_overrides']['hidden_size'] = 36 elif model == 'llama2': assert tie_word_embeddings is None if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: pytest.skip( - 'The CI cluster does not have access to the Llama models, so skip this test.' + 'The CI cluster does not have access to the Llama models, so skip this test.', ) om_cfg = get_config( - conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml') - om_cfg['model'][ - 'pretrained_model_name_or_path'] = 'meta-llama/Llama-2-7b-hf' + conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml', + ) + om_cfg['model']['pretrained_model_name_or_path' + ] = 'meta-llama/Llama-2-7b-hf' om_cfg['model']['config_overrides']['num_hidden_layers'] = 2 om_cfg['model']['use_auth_token'] = True om_cfg['tokenizer']['name'] = 'meta-llama/Llama-2-7b-hf' @@ -870,33 +968,44 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool, om_cfg['model']['init_device'] = 'cpu' tokenizer = transformers.AutoTokenizer.from_pretrained( - om_cfg.tokenizer.name, use_auth_token=model == 'llama2') + om_cfg.tokenizer.name, + use_auth_token=model == 'llama2', + ) original_model = build_composer_model( name=om_cfg['model'].name, cfg=om_cfg['model'], tokenizer=tokenizer, ) - trainer = Trainer(model=original_model, - device='cpu' if not model == 'mptmoe' else 'gpu') + trainer = Trainer( + model=original_model, + device='cpu' if not model == 'mptmoe' else 'gpu', + ) trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt')) - args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'), - hf_output_path=os.path.join(tmp_path, 'hf-output-folder'), - output_precision='fp32', - local_checkpoint_save_location=None, - hf_repo_for_upload=None, - trust_remote_code=False, - test_uploaded_model=False) + args = Namespace( + composer_path=os.path.join(tmp_path, 'checkpoint.pt'), + hf_output_path=os.path.join(tmp_path, 'hf-output-folder'), + output_precision='fp32', + local_checkpoint_save_location=None, + hf_repo_for_upload=None, + trust_remote_code=False, + test_uploaded_model=False, + ) convert_composer_to_hf(args) loaded_config = transformers.AutoConfig.from_pretrained( - os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True) + os.path.join(tmp_path, 'hf-output-folder'), + trust_remote_code=True, + ) loaded_model = transformers.AutoModelForCausalLM.from_pretrained( os.path.join(tmp_path, 'hf-output-folder'), config=loaded_config, - trust_remote_code=True) + trust_remote_code=True, + ) tokenizer = transformers.AutoTokenizer.from_pretrained( - os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True) + os.path.join(tmp_path, 'hf-output-folder'), + trust_remote_code=True, + ) device = 'cuda' if model == 'mptmoe' else 'cpu' precision = torch.bfloat16 if model == 'mptmoe' else torch.float32 @@ -905,32 +1014,44 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool, loaded_model.to(device) loaded_model.to(precision) - output = loaded_model.generate(tokenizer( - 'hello', return_tensors='pt')['input_ids'].to(device), - max_new_tokens=1) + output = loaded_model.generate( + tokenizer('hello', return_tensors='pt')['input_ids'].to(device), + max_new_tokens=1, + ) assert output.shape == (1, 2 + (1 if model == 'llama2' else 0)) - assert sum(p.numel() for p in original_model.model.parameters()) == sum( - p.numel() for p in loaded_model.parameters()) + assert sum(p.numel() for p in original_model.model.parameters() + ) == sum(p.numel() for p in loaded_model.parameters()) assert all( str(type(module1)).split('.')[-1] == str(type(module2)).split('.')[-1] - for module1, module2 in zip(original_model.model.modules(), - loaded_model.modules())) - for p1, p2 in zip(original_model.model.parameters(), - loaded_model.parameters()): + for module1, module2 in + zip(original_model.model.modules(), loaded_model.modules()) + ) + for p1, p2 in zip( + original_model.model.parameters(), + loaded_model.parameters(), + ): assert torch.allclose(p1, p2) delete_transformers_cache() -@pytest.mark.parametrize('conf_path', [ - 'scripts/train/yamls/pretrain/testing.yaml', - pytest.param('scripts/train/yamls/pretrain/testing-moe.yaml', - marks=pytest.mark.gpu), -]) +@pytest.mark.parametrize( + 'conf_path', + [ + 'scripts/train/yamls/pretrain/testing.yaml', + pytest.param( + 'scripts/train/yamls/pretrain/testing-moe.yaml', + marks=pytest.mark.gpu, + ), + ], +) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_convert_and_generate_meta(tie_word_embeddings: str, - tmp_path: pathlib.Path, conf_path: str): +def test_convert_and_generate_meta( + tie_word_embeddings: str, + tmp_path: pathlib.Path, + conf_path: str, +): delete_transformers_cache() from composer.utils import dist @@ -942,44 +1063,52 @@ def test_convert_and_generate_meta(tie_word_embeddings: str, om_cfg['model']['init_device'] = 'cpu' om_cfg['tie_word_embeddings'] = tie_word_embeddings tokenizer = transformers.AutoTokenizer.from_pretrained( - om_cfg.tokenizer.name) + om_cfg.tokenizer.name, + ) original_model = build_composer_model( name=om_cfg['model'].name, cfg=om_cfg['model'], tokenizer=tokenizer, ) - trainer = Trainer(model=original_model, - device='cpu' if not 'moe' in conf_path else 'gpu') + trainer = Trainer( + model=original_model, + device='cpu' if not 'moe' in conf_path else 'gpu', + ) trainer.save_checkpoint(os.path.join(tmp_path_gathered, 'checkpoint.pt')) # patch in the meta device for testing - sd = torch.load(os.path.join(tmp_path_gathered, 'checkpoint.pt'), - map_location='cpu') + sd = torch.load( + os.path.join(tmp_path_gathered, 'checkpoint.pt'), + map_location='cpu', + ) sd['state']['integrations']['huggingface']['model']['config']['content'][ 'init_device'] = 'meta' torch.save(sd, os.path.join(tmp_path_gathered, 'checkpoint.pt')) - args = Namespace(composer_path=os.path.join(tmp_path_gathered, - 'checkpoint.pt'), - hf_output_path=os.path.join(tmp_path_gathered, - 'hf-output-folder'), - output_precision='fp32', - local_checkpoint_save_location=None, - hf_repo_for_upload=None, - trust_remote_code=False, - test_uploaded_model=False) + args = Namespace( + composer_path=os.path.join(tmp_path_gathered, 'checkpoint.pt'), + hf_output_path=os.path.join(tmp_path_gathered, 'hf-output-folder'), + output_precision='fp32', + local_checkpoint_save_location=None, + hf_repo_for_upload=None, + trust_remote_code=False, + test_uploaded_model=False, + ) convert_composer_to_hf(args) loaded_config = transformers.AutoConfig.from_pretrained( os.path.join(tmp_path_gathered, 'hf-output-folder'), - trust_remote_code=True) + trust_remote_code=True, + ) loaded_model = transformers.AutoModelForCausalLM.from_pretrained( os.path.join(tmp_path_gathered, 'hf-output-folder'), config=loaded_config, - trust_remote_code=True) + trust_remote_code=True, + ) tokenizer = transformers.AutoTokenizer.from_pretrained( os.path.join(tmp_path_gathered, 'hf-output-folder'), - trust_remote_code=True) + trust_remote_code=True, + ) device = 'cuda' if 'moe' in conf_path else 'cpu' precision = torch.bfloat16 if 'moe' in conf_path else torch.float32 @@ -988,19 +1117,23 @@ def test_convert_and_generate_meta(tie_word_embeddings: str, loaded_model.to(device) loaded_model.to(precision) - output = loaded_model.generate(tokenizer( - 'hello', return_tensors='pt')['input_ids'].to(device), - max_new_tokens=1) + output = loaded_model.generate( + tokenizer('hello', return_tensors='pt')['input_ids'].to(device), + max_new_tokens=1, + ) assert output.shape == (1, 2) - assert sum(p.numel() for p in original_model.model.parameters()) == sum( - p.numel() for p in loaded_model.parameters()) + assert sum(p.numel() for p in original_model.model.parameters() + ) == sum(p.numel() for p in loaded_model.parameters()) assert all( str(type(module1)).split('.')[-1] == str(type(module2)).split('.')[-1] - for module1, module2 in zip(original_model.model.modules(), - loaded_model.modules())) - for p1, p2 in zip(original_model.model.parameters(), - loaded_model.parameters()): + for module1, module2 in + zip(original_model.model.modules(), loaded_model.modules()) + ) + for p1, p2 in zip( + original_model.model.parameters(), + loaded_model.parameters(), + ): assert torch.allclose(p1, p2) delete_transformers_cache() @@ -1127,7 +1260,7 @@ def test_mptmoe_huggingface_conversion_callback( 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } dataloader_cfg = om.create(dataloader_cfg) @@ -1160,8 +1293,11 @@ def test_mptmoe_huggingface_conversion_callback( init_context=init_context, ) - optimizer = build_optimizer(original_model, optimizer_name, - optimizer_config) + optimizer = build_optimizer( + original_model, + optimizer_name, + optimizer_config, + ) trainer = Trainer( model=original_model, device='gpu', @@ -1183,9 +1319,11 @@ def test_mptmoe_huggingface_conversion_callback( # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - with FSDP.summon_full_params(trainer.state.model, - writeback=False, - recurse=True): + with FSDP.summon_full_params( + trainer.state.model, + writeback=False, + recurse=True, + ): loaded_model = None loaded_tokenizer = None # Only rank zero is saving the huggingface checkpoints, so only check @@ -1196,10 +1334,11 @@ def test_mptmoe_huggingface_conversion_callback( for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) if name != 'huggingface' ] - huggingface_checkpoints = [ - name for name in os.listdir( - os.path.join(tmp_path, 'checkpoints', 'huggingface')) - ] + huggingface_checkpoints = list( + os.listdir( + os.path.join(tmp_path, 'checkpoints', 'huggingface'), + ), + ) assert len(normal_checkpoints) == expected_normal_checkpoints assert len(huggingface_checkpoints) == expected_hf_checkpoints @@ -1208,8 +1347,12 @@ def test_mptmoe_huggingface_conversion_callback( with patch.dict('sys.modules', {'flash_attn': None}): # Load the last huggingface checkpoint loaded_model = transformers.AutoModelForCausalLM.from_pretrained( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba1'), + os.path.join( + tmp_path, + 'checkpoints', + 'huggingface', + f'ba1', + ), trust_remote_code=True, ) @@ -1219,16 +1362,22 @@ def test_mptmoe_huggingface_conversion_callback( loaded_model.config.torch_dtype = original_model.model.config.torch_dtype loaded_tokenizer = transformers.AutoTokenizer.from_pretrained( - os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{batches_per_epoch}'), + os.path.join( + tmp_path, + 'checkpoints', + 'huggingface', + f'ba{batches_per_epoch}', + ), trust_remote_code=True, ) for n, p in trainer.state.model.model.named_parameters(): if isinstance(p, DTensor): submodule_name, param_name = '.'.join( - n.split('.')[:-1]), n.split('.')[-1] + n.split('.')[:-1], + ), n.split('.')[-1] submodule = trainer.state.model.model.get_submodule( - submodule_name) + submodule_name, + ) param_tensor = p.full_tensor() param = torch.nn.Parameter(param_tensor) submodule.register_parameter(param_name, param) @@ -1255,9 +1404,12 @@ def test_mptmoe_huggingface_conversion_callback( @pytest.mark.parametrize( 'license_file_name', - ['LICENSE', 'LICENSE.txt', 'license', 'license.md', None]) -def test_license_file_finder(tmp_path: pathlib.Path, - license_file_name: Optional[str]): + ['LICENSE', 'LICENSE.txt', 'license', 'license.md', None], +) +def test_license_file_finder( + tmp_path: pathlib.Path, + license_file_name: Optional[str], +): if license_file_name is not None: with open(os.path.join(tmp_path, license_file_name), 'w') as f: f.write('test') diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index fe58a44459..f721e0499d 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -12,14 +12,20 @@ from omegaconf import OmegaConf as om from scripts.train.train import main, validate_config # noqa: E402 -from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, - gpt_tiny_cfg) +from tests.data_utils import ( + create_arxiv_dataset, + create_c4_dataset_xxsmall, + gpt_tiny_cfg, +) from tests.fixtures.autouse import REPO_DIR -@pytest.mark.parametrize('averages', [{ - 'core_average': ['language_understanding_lite'] -}, None]) +@pytest.mark.parametrize( + 'averages', + [{ + 'core_average': ['language_understanding_lite'], + }, None], +) def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): """Test training run with a small dataset.""" dataset_name = create_c4_dataset_xxsmall(tmp_path) @@ -32,8 +38,8 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): 'scripts/eval/local_data/language_understanding/lambada_openai_small.jsonl', 'num_fewshot': [0], 'icl_task_type': - 'language_modeling' - }) + 'language_modeling', + }), ]) test_cfg.icl_subset_num_batches = 1 test_cfg.eval_subset_num_batches = 2 @@ -61,11 +67,11 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): DictConfig({ 'name': 'lambada_openai', 'num_fewshot': 0, - 'random_baseline': 0.0 - }) - ]) - }) - ]) + 'random_baseline': 0.0, + }), + ]), + }), + ]), }) if averages is not None: @@ -88,12 +94,16 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): assert f'icl/metrics/eval_gauntlet/{category_name}' in inmemorylogger.data.keys( ) assert isinstance( - inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'], list) - assert len(inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'] - [-1]) > 0 + inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'], + list, + ) + assert len( + inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'][-1], + ) > 0 assert isinstance( inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'][-1], - tuple) + tuple, + ) assert inmemorylogger.data[f'icl/metrics/eval_gauntlet/{category_name}'][ -1][-1] == 0 @@ -129,22 +139,31 @@ def test_train_multi_eval(tmp_path: pathlib.Path): # Checks for first eval dataloader assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list) + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + list, + ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0 + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple) + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + tuple, + ) # Checks for second eval dataloader assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys( ) assert isinstance( - inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list) + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], + list, + ) assert len( - inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0 + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1], + ) > 0 assert isinstance( inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1], - tuple) + tuple, + ) @pytest.mark.gpu @@ -158,9 +177,9 @@ def test_validate_config(): test_cfg.model.ffn_config.moe_world_size = 4 test_cfg.fsdp_config.use_orig_params = False with pytest.raises( - ValueError, - match= - 'MoEs with expert parallelism (.*) require `use_orig_params=True`.' + ValueError, + match= + 'MoEs with expert parallelism (.*) require `use_orig_params=True`.', ): validate_config(test_cfg) @@ -186,8 +205,13 @@ def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list) + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + list, + ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0 + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple) + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + tuple, + ) diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py index 5eb24e05c8..ad36630def 100644 --- a/tests/a_scripts/train/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -35,9 +35,9 @@ def make_fake_index_file(path: str) -> None: 'basename': 'shard.00000.mds.zstd', 'bytes': 564224, 'hashes': {}, - } + }, }], - 'version': 2 + 'version': 2, } if not os.path.exists(path): os.makedirs(os.path.dirname(path), exist_ok=True) @@ -52,7 +52,9 @@ class TestTrainingYAMLInputs: def cfg(self, foundry_dir: str) -> DictConfig: """Create YAML cfg fixture for testing purposes.""" conf_path: str = os.path.join( - foundry_dir, 'scripts/train/yamls/pretrain/testing.yaml') + foundry_dir, + 'scripts/train/yamls/pretrain/testing.yaml', + ) with open(conf_path, 'r', encoding='utf-8') as config: test_cfg = om.load(config) assert isinstance(test_cfg, DictConfig) @@ -80,12 +82,15 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: for param in mandatory_params: orig_param = cfg.pop(param) with pytest.raises( - (omegaconf.errors.ConfigAttributeError, NameError)): + (omegaconf.errors.ConfigAttributeError, NameError), + ): main(cfg) cfg[param] = orig_param - def test_optional_misspelled_params_raise_warning(self, - cfg: DictConfig) -> None: + def test_optional_misspelled_params_raise_warning( + self, + cfg: DictConfig, + ) -> None: """Check that warnings are raised for optional misspelled parameters.""" optional_params = [ 'save_weights_only', @@ -111,13 +116,17 @@ def test_optional_misspelled_params_raise_warning(self, main(cfg) except: pass - assert any(f'Unused parameter {updated_param} found in cfg.' in - str(warning.message) for warning in warning_list) + assert any( + f'Unused parameter {updated_param} found in cfg.' in + str(warning.message) for warning in warning_list + ) # restore configs. cfg = copy.deepcopy(old_cfg) - def test_extra_params_in_optimizer_cfg_errors(self, - cfg: DictConfig) -> None: + def test_extra_params_in_optimizer_cfg_errors( + self, + cfg: DictConfig, + ) -> None: data_local = './my-copy-c4-opt1' make_fake_index_file(f'{data_local}/train/index.json') make_fake_index_file(f'{data_local}/val/index.json') @@ -127,8 +136,10 @@ def test_extra_params_in_optimizer_cfg_errors(self, with pytest.raises(TypeError): main(cfg) - def test_invalid_name_in_optimizer_cfg_errors(self, - cfg: DictConfig) -> None: + def test_invalid_name_in_optimizer_cfg_errors( + self, + cfg: DictConfig, + ) -> None: data_local = './my-copy-c4-opt2' make_fake_index_file(f'{data_local}/train/index.json') make_fake_index_file(f'{data_local}/val/index.json') @@ -138,22 +149,26 @@ def test_invalid_name_in_optimizer_cfg_errors(self, with pytest.raises(ValueError) as exception_info: main(cfg) assert str(exception_info.value).startswith( - "Cant't find 'invalid-optimizer' in registry llmfoundry -> optimizers." + "Cant't find 'invalid-optimizer' in registry llmfoundry -> optimizers.", ) - def test_extra_params_in_scheduler_cfg_errors(self, - cfg: DictConfig) -> None: + def test_extra_params_in_scheduler_cfg_errors( + self, + cfg: DictConfig, + ) -> None: cfg.scheduler.t_warmup_extra = 'extra-parameter' with pytest.raises(TypeError): main(cfg) - def test_invalid_name_in_scheduler_cfg_errors(self, - cfg: DictConfig) -> None: + def test_invalid_name_in_scheduler_cfg_errors( + self, + cfg: DictConfig, + ) -> None: cfg.scheduler.name = 'invalid-scheduler' with pytest.raises(ValueError) as exception_info: main(cfg) assert str(exception_info.value).startswith( - "Cant't find 'invalid-scheduler' in registry llmfoundry -> schedulers." + "Cant't find 'invalid-scheduler' in registry llmfoundry -> schedulers.", ) def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None: @@ -172,6 +187,6 @@ def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None: with pytest.raises(ValueError) as exception_info: main(cfg) assert str( - exception_info.value + exception_info.value, ) == 'When specifying multiple evaluation datasets, each one must include the \ `label` attribute.' diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index a95dbd7029..e0794df08f 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -8,11 +8,13 @@ import pytest from composer.core import Time, Timestamp, TimeUnit -from llmfoundry.callbacks.async_eval_callback import (AsyncEval, - get_eval_parameters, - get_run_name, - validate_eval_run_config, - validate_interval) +from llmfoundry.callbacks.async_eval_callback import ( + AsyncEval, + get_eval_parameters, + get_run_name, + validate_eval_run_config, + validate_interval, +) from llmfoundry.utils.builders import build_callback from mcli import Run, RunConfig, RunStatus @@ -27,9 +29,9 @@ 'name': 'model_example', 'config_overrides': { 'attn_config': { - 'foo': 'bar' - } - } + 'foo': 'bar', + }, + }, }, 'tokenizer': { 'tokenizer_example': 'tokenizer_example', @@ -48,40 +50,47 @@ def test_get_run_name(): @pytest.fixture(autouse=True, scope='module') def set_os_env_vars(): - with patch.dict('os.environ', { + with patch.dict( + 'os.environ', + { 'MOSAICML_PLATFORM': 'true', - 'RUN_NAME': RUN_NAME - }): + 'RUN_NAME': RUN_NAME, + }, + ): yield def test_fails_when_not_on_platform(): with patch.dict('os.environ', {'MOSAICML_PLATFORM': 'false'}): with pytest.raises( - Exception, - match= - 'AsyncEval callback is only supported when running on the MosaicML platform' + Exception, + match= + 'AsyncEval callback is only supported when running on the MosaicML platform', ): AsyncEval(BASIC_PARAMS, interval='2ba') def test_fails_when_no_run_name(): - with patch.dict('os.environ', { + with patch.dict( + 'os.environ', + { 'MOSAICML_PLATFORM': 'true', - 'RUN_NAME': '' - }): + 'RUN_NAME': '', + }, + ): with pytest.raises( - Exception, - match= - 'RUN_NAME environment variable must be set to use the AsyncEval callback' + Exception, + match= + 'RUN_NAME environment variable must be set to use the AsyncEval callback', ): AsyncEval(BASIC_PARAMS, interval='2ba') def test_get_eval_parameters(): with pytest.raises( - Exception, - match='Missing the following required parameters for async eval:'): + Exception, + match='Missing the following required parameters for async eval:', + ): get_eval_parameters({}, 'checkpoints/file', RUN_NAME) # minimal example @@ -99,12 +108,12 @@ def test_get_eval_parameters(): 'name': 'model_example', 'config_overrides': { 'attn_config': { - 'foo': 'bar' + 'foo': 'bar', }, }, }, 'tokenizer': { - 'tokenizer_example': 'tokenizer_example' + 'tokenizer_example': 'tokenizer_example', }, 'load_path': 'checkpoints/file', }], @@ -119,15 +128,15 @@ def test_get_eval_parameters(): 'dist_timeout': 1, 'eval_gauntlet': 'eval_gauntlet_example', 'fsdp_config': { - 'fsdp_cfg_example': 'fsdp_cfg_example' + 'fsdp_cfg_example': 'fsdp_cfg_example', }, 'icl_subset_num_batches': 4, 'loggers': { 'wandb': { 'init_kwargs': { - 'fee': 'bee' - } - } + 'fee': 'bee', + }, + }, }, 'precision': 'precision_example', 'python_log_level': 'debug', @@ -149,27 +158,27 @@ def test_get_eval_parameters(): 'name': 'model_example', 'config_overrides': { 'attn_config': { - 'foo': 'bar' + 'foo': 'bar', }, }, }, 'tokenizer': { - 'tokenizer_example': 'tokenizer_example' + 'tokenizer_example': 'tokenizer_example', }, 'load_path': 'checkpoints/file', }], 'eval_gauntlet': 'eval_gauntlet_example', 'fsdp_config': { - 'fsdp_cfg_example': 'fsdp_cfg_example' + 'fsdp_cfg_example': 'fsdp_cfg_example', }, 'icl_subset_num_batches': 4, 'loggers': { 'wandb': { 'group': 'foo_bar-1234', 'init_kwargs': { - 'fee': 'bee' + 'fee': 'bee', }, - } + }, }, 'precision': 'precision_example', 'python_log_level': 'debug', @@ -243,8 +252,10 @@ def test_validate_eval_run_config(): ) -@patch('llmfoundry.callbacks.async_eval_callback.get_run', - return_value=FAKE_RUN) +@patch( + 'llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN, +) def test_async_eval_callback_builds(mock_get_run: MagicMock): kwargs = {'interval': 1} config = { @@ -265,12 +276,18 @@ def test_async_eval_callback_builds(mock_get_run: MagicMock): assert mock_get_run.call_args[0][0] == RUN_NAME -@patch('llmfoundry.callbacks.async_eval_callback.get_run', - return_value=FAKE_RUN) -@patch('llmfoundry.callbacks.async_eval_callback.create_run', - return_value=FAKE_RUN) -def test_async_eval_callback_minimal(mock_create_run: MagicMock, - mock_get_run: MagicMock): +@patch( + 'llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN, +) +@patch( + 'llmfoundry.callbacks.async_eval_callback.create_run', + return_value=FAKE_RUN, +) +def test_async_eval_callback_minimal( + mock_create_run: MagicMock, + mock_get_run: MagicMock, +): callback = AsyncEval( BASIC_PARAMS, interval='2ba', @@ -324,20 +341,22 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, 'name': 'model_example', 'config_overrides': { 'attn_config': { - 'foo': 'bar' + 'foo': 'bar', }, }, }, 'tokenizer': { - 'tokenizer_example': 'tokenizer_example' + 'tokenizer_example': 'tokenizer_example', }, 'load_path': 'checkpoint/path', }] assert parameters['run_name'] == 'eval-1ba-foo_bar' # original run -@patch('llmfoundry.callbacks.async_eval_callback.get_run', - return_value=FAKE_RUN) +@patch( + 'llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN, +) def test_async_eval_state(mock_create_run: MagicMock): callback = AsyncEval(BASIC_PARAMS, interval='2ba') @@ -387,23 +406,31 @@ def test_async_eval_state(mock_create_run: MagicMock): FAKE_RUN_WITH_INTEGRATIONS = deepcopy(FAKE_RUN) FAKE_RUN_WITH_INTEGRATIONS.submitted_config.integrations = [ - INTEGRATION_GIT_LLMFOUNDRY, INTEGRATION_GIT_RANDOM + INTEGRATION_GIT_LLMFOUNDRY, + INTEGRATION_GIT_RANDOM, ] -@patch('llmfoundry.callbacks.async_eval_callback.get_run', - return_value=FAKE_RUN_WITH_INTEGRATIONS) -@patch('llmfoundry.callbacks.async_eval_callback.create_run', - return_value=FAKE_RUN_WITH_INTEGRATIONS) -def test_async_eval_callback_integrations(mock_create_run: MagicMock, - mock_get_run: MagicMock): +@patch( + 'llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN_WITH_INTEGRATIONS, +) +@patch( + 'llmfoundry.callbacks.async_eval_callback.create_run', + return_value=FAKE_RUN_WITH_INTEGRATIONS, +) +def test_async_eval_callback_integrations( + mock_create_run: MagicMock, + mock_get_run: MagicMock, +): callback = AsyncEval( BASIC_PARAMS, interval='2ba', eval_run_config={'compute': { 'cluster': 'c2z3', 'nodes': 2, - }}) + }}, + ) assert mock_get_run.call_count == 1 callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH)) @@ -419,8 +446,10 @@ def test_async_eval_callback_integrations(mock_create_run: MagicMock, assert f'cd {custom_path}/scripts' in run_config_created.command -@patch('llmfoundry.callbacks.async_eval_callback.dist.get_world_size', - return_value=4) +@patch( + 'llmfoundry.callbacks.async_eval_callback.dist.get_world_size', + return_value=4, +) def test_get_ready_sharded_checkpoints(mocked_get_world_size: MagicMock): assert not AsyncEval._get_ready_sharded_checkpoints({}, []) assert not AsyncEval._get_ready_sharded_checkpoints( diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py index 737ec75d26..bbdbf3d691 100644 --- a/tests/callbacks/test_curriculum_learning_callback.py +++ b/tests/callbacks/test_curriculum_learning_callback.py @@ -6,7 +6,9 @@ def test_curriculum_learning_callback_builds(): kwargs = {'dataset_index': 0} - callback = build_callback('curriculum_learning', - kwargs=kwargs, - train_config={'train_loader': {}}) + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config={'train_loader': {}}, + ) assert callback is not None diff --git a/tests/callbacks/test_eval_gauntlet_callback.py b/tests/callbacks/test_eval_gauntlet_callback.py index 8d9938e3a1..8d4df43d63 100644 --- a/tests/callbacks/test_eval_gauntlet_callback.py +++ b/tests/callbacks/test_eval_gauntlet_callback.py @@ -34,13 +34,12 @@ def __init__(self, logger_keys: List[str], accuracy: float = 0.25) -> None: for key in logger_keys: dl_name = '/'.join(key.split('/')[1:-1]) self.eval_metrics[dl_name] = {} - self.eval_metrics[dl_name][ - 'InContextLearningLMAccuracy'] = InContextLearningLMAccuracy() - self.eval_metrics[dl_name][ - 'InContextLearningLMAccuracy'].correct = torch.tensor(accuracy * - 100) - self.eval_metrics[dl_name][ - 'InContextLearningLMAccuracy'].total = torch.tensor(100) + self.eval_metrics[dl_name]['InContextLearningLMAccuracy' + ] = InContextLearningLMAccuracy() + self.eval_metrics[dl_name]['InContextLearningLMAccuracy' + ].correct = torch.tensor(accuracy * 100) + self.eval_metrics[dl_name]['InContextLearningLMAccuracy' + ].total = torch.tensor(100) class MockLogger(Logger): @@ -53,11 +52,15 @@ def log_metrics(self, metrics: Dict[str, float]) -> None: self.inmemorylogger.log_metrics(metrics) -@pytest.mark.parametrize('averages', [{ - 'core_average': ['world_knowledge', 'language_understanding'] -}, None]) +@pytest.mark.parametrize( + 'averages', + [{ + 'core_average': ['world_knowledge', 'language_understanding'], + }, None], +) def test_gauntlet_callback(averages: Optional[dict]): - icl_task_config = om.OmegaConf.create(""" + icl_task_config = om.OmegaConf.create( + """ - label: jeopardy_small dataset_uri: eval/local_data/world_knowledge/jeopardy_small.jsonl # ADD YOUR OWN DATASET URI num_fewshot: [10] @@ -68,11 +71,13 @@ def test_gauntlet_callback(averages: Optional[dict]): dataset_uri: eval/local_data/language_understanding/lambada_openai_small.jsonl # ADD YOUR OWN DATASET URI num_fewshot: [0] icl_task_type: language_modeling - """) - assert isinstance(icl_task_config, om.ListConfig) or isinstance( - icl_task_config, str) + """, + ) + assert isinstance(icl_task_config, + om.ListConfig) or isinstance(icl_task_config, str) - eval_gauntlet_config = om.OmegaConf.create(""" + eval_gauntlet_config = om.OmegaConf.create( + """ weighting: EQUAL subtract_random_baseline: true rescale_accuracy: true @@ -87,9 +92,10 @@ def test_gauntlet_callback(averages: Optional[dict]): - name: lambada_openai_small num_fewshot: 0 random_baseline: 0.0 - """) - assert isinstance(eval_gauntlet_config, om.DictConfig) or isinstance( - eval_gauntlet_config, str) + """, + ) + assert isinstance(eval_gauntlet_config, + om.DictConfig) or isinstance(eval_gauntlet_config, str) if averages is not None: eval_gauntlet_config.averages = averages @@ -97,7 +103,13 @@ def test_gauntlet_callback(averages: Optional[dict]): # test loading functionality _, _, eval_gauntlet_callback = build_icl_data_and_gauntlet( - icl_task_config, eval_gauntlet_config, tokenizer, 4, 1024, 1) + icl_task_config, + eval_gauntlet_config, + tokenizer, + 4, + 1024, + 1, + ) assert eval_gauntlet_callback is not None state = MockState(eval_gauntlet_callback.logger_keys) logger = MockLogger(state) @@ -106,15 +118,15 @@ def test_gauntlet_callback(averages: Optional[dict]): result = eval_gauntlet_callback.eval_after_all(state, logger) for category in [ - 'world_knowledge', - 'language_understanding', + 'world_knowledge', + 'language_understanding', ]: name = f'icl/metrics/eval_gauntlet/{category}' assert result[name] == pytest.approx(0.25) if averages is None: - assert result[ - 'icl/metrics/eval_gauntlet/default_average'] == pytest.approx(0.25) + assert result['icl/metrics/eval_gauntlet/default_average' + ] == pytest.approx(0.25) else: - assert result[ - 'icl/metrics/eval_gauntlet/core_average'] == pytest.approx(0.25) + assert result['icl/metrics/eval_gauntlet/core_average' + ] == pytest.approx(0.25) diff --git a/tests/conftest.py b/tests/conftest.py index 545dc7e38f..b099a88cd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,10 +24,12 @@ ] -def _add_option(parser: pytest.Parser, - name: str, - help: str, - choices: Optional[List[str]] = None): +def _add_option( + parser: pytest.Parser, + name: str, + help: str, + choices: Optional[List[str]] = None, +): parser.addoption( f'--{name}', default=None, @@ -44,11 +46,13 @@ def _add_option(parser: pytest.Parser, def pytest_addoption(parser: pytest.Parser) -> None: - _add_option(parser, - 'seed', - help="""\ + _add_option( + parser, + 'seed', + help="""\ Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked - before each test.""") + before each test.""", + ) def _get_world_size(item: pytest.Item): @@ -57,8 +61,10 @@ def _get_world_size(item: pytest.Item): return item.get_closest_marker('world_size', default=_default).args[0] -def pytest_collection_modifyitems(config: pytest.Config, - items: List[pytest.Item]) -> None: +def pytest_collection_modifyitems( + config: pytest.Config, + items: List[pytest.Item], +) -> None: """Filter tests by world_size (for multi-GPU tests)""" world_size = int(os.environ.get('WORLD_SIZE', '1')) @@ -70,7 +76,7 @@ def pytest_collection_modifyitems(config: pytest.Config, remaining = [] deselected = [] for item in items: - if all([condition(item) for condition in conditions]): + if all(condition(item) for condition in conditions): remaining.append(item) else: deselected.append(item) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 45cba52063..9eed7f26e7 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -23,33 +23,43 @@ from streaming.base.util import clean_stale_shared_memory from llmfoundry.data import build_dataloader, build_finetuning_dataloader -from llmfoundry.data.finetuning.collator import (_HF_IGNORE_INDEX, - validate_target_settings) -from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH, - SUPPORTED_EXTENSIONS, - dataset_constructor, - is_valid_ift_example, - tokenize_formatted_example) -from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, - build_text_dataloader, - get_tokens_per_batch_func) +from llmfoundry.data.finetuning.collator import ( + _HF_IGNORE_INDEX, + validate_target_settings, +) +from llmfoundry.data.finetuning.tasks import ( + DOWNLOADED_FT_DATASETS_DIRPATH, + SUPPORTED_EXTENSIONS, + dataset_constructor, + is_valid_ift_example, + tokenize_formatted_example, +) +from llmfoundry.data.text_data import ( + ConcatenatedSequenceCollatorWrapper, + build_text_dataloader, + get_tokens_per_batch_func, +) from llmfoundry.utils.builders import build_tokenizer # yapf: disable -from llmfoundry.utils.exceptions import (ConsecutiveRepeatedChatRolesError, - IncorrectMessageKeyQuantityError, - InvalidContentTypeError, - InvalidLastChatMessageRoleError, - InvalidPromptTypeError, - InvalidResponseTypeError, - InvalidRoleError, - MisconfiguredHfDatasetError, - NotEnoughDatasetSamplesError, - UnknownExampleTypeError) +from llmfoundry.utils.exceptions import ( + ConsecutiveRepeatedChatRolesError, + IncorrectMessageKeyQuantityError, + InvalidContentTypeError, + InvalidLastChatMessageRoleError, + InvalidPromptTypeError, + InvalidResponseTypeError, + InvalidRoleError, + MisconfiguredHfDatasetError, + NotEnoughDatasetSamplesError, + UnknownExampleTypeError, +) # yapf: enable from scripts.data_prep.convert_dataset_hf import main as main_hf from scripts.data_prep.convert_finetuning_dataset import get_columns_and_format -from tests.data_utils import (make_tiny_conversation_ft_dataset, - make_tiny_ft_dataset) +from tests.data_utils import ( + make_tiny_conversation_ft_dataset, + make_tiny_ft_dataset, +) from tests.test_utils import generate_exclusive_test_params @@ -69,29 +79,30 @@ def get_abs_data_path(data_local: str): def build_mock_ft_streaming_dataset( - data_path: str, - split: str, - pretokenize: bool, - backwards_compatibility_mode: bool, - use_bytes: bool, - tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None): + data_path: str, + split: str, + pretokenize: bool, + backwards_compatibility_mode: bool, + use_bytes: bool, + tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None, +): dataset = [{ 'prompt': 'This is just a test1', - 'response': 'Hello World1' + 'response': 'Hello World1', }, { 'prompt': 'This is just a test2', - 'response': 'Hello world2' + 'response': 'Hello world2', }, { 'prompt': 'This is just a test3', - 'response': 'Hello world3' + 'response': 'Hello world3', }] output_path = os.path.join(data_path, split) if use_bytes and not backwards_compatibility_mode: raise ValueError( - 'use_bytes should only be true when using backwards_compatibility_mode' + 'use_bytes should only be true when using backwards_compatibility_mode', ) # This is the old code-path, which we want to maintain test coverage of @@ -103,37 +114,51 @@ def build_mock_ft_streaming_dataset( else: columns = { 'input_ids': 'ndarray:uint32', - 'labels': 'ndarray:uint32' + 'labels': 'ndarray:uint32', } else: columns = {'prompt': 'str', 'response': 'str'} - with MDSWriter(columns=columns, out=output_path, - compression=None) as output_writer: + with MDSWriter( + columns=columns, + out=output_path, + compression=None, + ) as output_writer: for sample in dataset: if pretokenize: - sample = tokenize_formatted_example(sample, - tokenizer=tokenizer) + sample = tokenize_formatted_example( + sample, + tokenizer=tokenizer, + ) # Unpack the first turn to account for changes in `tokenize_formatted_example` sample = sample['turns'][0] sample_to_write = {} for key in columns.keys(): if use_bytes: sample_to_write[key] = np.asarray( - sample[key]).tobytes() + sample[key], + ).tobytes() else: - sample_to_write[key] = np.asarray(sample[key], - dtype=np.uint32) + sample_to_write[key] = np.asarray( + sample[key], + dtype=np.uint32, + ) output_writer.write(sample_to_write) else: output_writer.write(sample) return - columns, data_format = get_columns_and_format(dataset, pretokenize, - lambda x: x) + columns, data_format = get_columns_and_format( + dataset, + pretokenize, + lambda x: x, + ) - with MDSWriter(columns=columns, out=output_path, - compression=None) as output_writer: + with MDSWriter( + columns=columns, + out=output_path, + compression=None, + ) as output_writer: for sample in dataset: if pretokenize: sample = tokenize_formatted_example(sample, tokenizer=tokenizer) @@ -157,9 +182,11 @@ def build_mock_ft_streaming_dataset( @pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m']) @pytest.mark.parametrize('pretokenize', [False, True]) -def test_correct_padding(tokenizer_name: str, - pretokenize: bool, - batch_size: int = 4): +def test_correct_padding( + tokenizer_name: str, + pretokenize: bool, + batch_size: int = 4, +): if tokenizer_name == 'gpt2' and not pretokenize: pytest.xfail('Must pretokenize data if using "gpt2" tokenizer') @@ -189,8 +216,10 @@ def test_correct_padding(tokenizer_name: str, 'bos_text': bos_text, 'eos_text': eos_text, 'no_wrap': False, - 'num_workers': None - })) + 'num_workers': None, + }, + ), + ) else: main_hf( Namespace( @@ -207,12 +236,15 @@ def test_correct_padding(tokenizer_name: str, 'eos_text': eos_text, 'no_wrap': False, 'num_workers': None, - })) + }, + ), + ) if not os.path.isdir(path): raise RuntimeError(f'c4 dataset at {path} not set up as expected') test_cfg = get_config( - conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml') + conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml', + ) test_cfg.data_local = data_local test_cfg.eval_loader.dataset.split = split test_cfg.dataset = om.create({ @@ -238,7 +270,9 @@ def test_correct_padding(tokenizer_name: str, # we follow the convention (from huggingface) that non-attended tokens are 0 in the attn mask and -100 in the labels attention_mask = batch.get( - 'attention_mask', torch.ones_like(batch['input_ids'], dtype=torch.bool)) + 'attention_mask', + torch.ones_like(batch['input_ids'], dtype=torch.bool), + ) a = attention_mask == 0 b = batch['labels'] == -100 assert torch.equal(a, b) @@ -247,8 +281,10 @@ def test_correct_padding(tokenizer_name: str, @pytest.mark.parametrize(('eos_token_id', 'bos_token_id'), [(5, None), (None, 5), pytest.param(5, 5, marks=pytest.mark.xfail)]) -def test_sequence_id_wrapper(eos_token_id: Optional[int], - bos_token_id: Optional[int]): +def test_sequence_id_wrapper( + eos_token_id: Optional[int], + bos_token_id: Optional[int], +): wrapper = ConcatenatedSequenceCollatorWrapper( lambda x: x, # placeholder eos_token_id=eos_token_id, @@ -259,11 +295,15 @@ def test_sequence_id_wrapper(eos_token_id: Optional[int], sequence_id = wrapper.get_sequence_id_from_batch(batch) if eos_token_id is not None: - assert torch.equal(sequence_id, - torch.Tensor([[0, 0, 0, 0, 1, 1, 1, 2, 2]])) + assert torch.equal( + sequence_id, + torch.Tensor([[0, 0, 0, 0, 1, 1, 1, 2, 2]]), + ) elif bos_token_id is not None: - assert torch.equal(sequence_id, - torch.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2]])) + assert torch.equal( + sequence_id, + torch.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2]]), + ) else: raise NotImplementedError() @@ -289,14 +329,15 @@ def test_invalid_jsonl_data(): 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) tokenizer = build_tokenizer( tokenizer_name='gpt2', - tokenizer_kwargs={'model_max_length': max_seq_len}) + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) device_batch_size = 2 @@ -305,19 +346,23 @@ def test_invalid_jsonl_data(): expected_keys += ['decoder_attention_mask', 'decoder_input_ids'] with pytest.raises(MisconfiguredHfDatasetError): - build_finetuning_dataloader(cfg, tokenizer, - device_batch_size).dataloader + build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ).dataloader @pytest.mark.parametrize('use_chat_formatting', [True, False]) @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('allow_pad_trimming', [True, False]) @pytest.mark.parametrize('packing_ratio', [10.0, None, 'auto']) -def test_finetuning_dataloader(use_chat_formatting: bool, - decoder_only_format: bool, - allow_pad_trimming: bool, - packing_ratio: Optional[Union[float, - Literal['auto']]]): +def test_finetuning_dataloader( + use_chat_formatting: bool, + decoder_only_format: bool, + allow_pad_trimming: bool, + packing_ratio: Optional[Union[float, Literal['auto']]], +): if (decoder_only_format is False) and (packing_ratio is not None): pytest.xfail('packing_ratio only supported for decoder-only format.') @@ -348,14 +393,15 @@ def test_finetuning_dataloader(use_chat_formatting: bool, 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) tokenizer = build_tokenizer( tokenizer_name=tokenizer_name, - tokenizer_kwargs={'model_max_length': max_seq_len}) + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) device_batch_size = 2 @@ -363,8 +409,11 @@ def test_finetuning_dataloader(use_chat_formatting: bool, if not decoder_only_format: expected_keys += ['decoder_attention_mask', 'decoder_input_ids'] - loader = build_finetuning_dataloader(cfg, tokenizer, - device_batch_size).dataloader + loader = build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ).dataloader batch_ix = 0 for batch in loader: for k in expected_keys: @@ -381,10 +430,13 @@ def test_finetuning_dataloader(use_chat_formatting: bool, @pytest.mark.parametrize( 'hf_name, hf_revision, expectation', [('HuggingFaceH4/databricks_dolly_15k', None, does_not_raise()), - ('squad', '5fe18c', pytest.raises(FileNotFoundError))]) -def test_finetuning_dataloader_safe_load(hf_name: str, - hf_revision: Optional[str], - expectation: ContextManager): + ('squad', '5fe18c', pytest.raises(FileNotFoundError))], +) +def test_finetuning_dataloader_safe_load( + hf_name: str, + hf_revision: Optional[str], + expectation: ContextManager, +): # Clear the folder shutil.rmtree(DOWNLOADED_FT_DATASETS_DIRPATH, ignore_errors=True) cfg = DictConfig({ @@ -397,15 +449,15 @@ def test_finetuning_dataloader_safe_load(hf_name: str, 'shuffle': True, 'safe_load': True, 'hf_kwargs': { - 'revision': hf_revision - } + 'revision': hf_revision, + }, }, 'drop_last': False, 'num_workers': 0, 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, }) tokenizer = build_tokenizer('gpt2', {}) @@ -422,7 +474,8 @@ def test_finetuning_dataloader_safe_load(hf_name: str, assert len(downloaded_files) > 0 assert all( Path(file).suffix in SUPPORTED_EXTENSIONS - for file in downloaded_files) + for file in downloaded_files + ) @pytest.mark.world_size(2) @@ -430,9 +483,11 @@ def test_finetuning_dataloader_safe_load(hf_name: str, @pytest.mark.parametrize('dataset_size', [4, 8]) @pytest.mark.parametrize('device_batch_size', [2, 4]) @pytest.mark.parametrize('drop_last', [True, False]) -def test_finetuning_dataloader_small_data(dataset_size: int, - device_batch_size: int, - drop_last: bool): +def test_finetuning_dataloader_small_data( + dataset_size: int, + device_batch_size: int, + drop_last: bool, +): tokenizer_name = 'gpt2' max_seq_len = 2048 tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small') @@ -459,7 +514,7 @@ def test_finetuning_dataloader_small_data(dataset_size: int, 'pin_memory': False, 'prefetch_factor': 2, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) @@ -471,8 +526,10 @@ def test_finetuning_dataloader_small_data(dataset_size: int, error_context = contextlib.nullcontext() if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last: - error_context = pytest.raises(NotEnoughDatasetSamplesError, - match='Your dataset') + error_context = pytest.raises( + NotEnoughDatasetSamplesError, + match='Your dataset', + ) with error_context: _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) @@ -486,8 +543,11 @@ def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): tokenizer_name = 'gpt2' max_seq_len = 2048 tiny_dataset_folder_path = str(tmp_path) - tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'data', - f'{split}-00000-of-00001.jsonl') + tiny_dataset_path = os.path.join( + tiny_dataset_folder_path, + 'data', + f'{split}-00000-of-00001.jsonl', + ) if dist.get_global_rank() == 0: make_tiny_ft_dataset(path=tiny_dataset_path, size=16) @@ -507,7 +567,7 @@ def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): 'pin_memory': False, 'prefetch_factor': 2, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) @@ -525,7 +585,8 @@ def mock_get_file(path: str, destination: str, overwrite: bool = False): make_tiny_ft_dataset(path=destination, size=16) else: raise FileNotFoundError( - f'Test error in mock_get_file. {path} does not exist.') + f'Test error in mock_get_file. {path} does not exist.', + ) @pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data']) @@ -549,7 +610,7 @@ def test_finetuning_dataloader_custom_split_remote(split: str): 'pin_memory': False, 'prefetch_factor': 2, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) @@ -560,8 +621,10 @@ def test_finetuning_dataloader_custom_split_remote(split: str): ) # Mock get_file to avoid downloading the file - with patch('llmfoundry.data.finetuning.dataloader.get_file', - wraps=mock_get_file) as f: + with patch( + 'llmfoundry.data.finetuning.dataloader.get_file', + wraps=mock_get_file, + ) as f: _ = build_finetuning_dataloader(cfg, tokenizer, 4) for call in f.call_args_list: path_arg = call.kwargs['path'] @@ -577,11 +640,13 @@ def test_finetuning_dataloader_custom_split_remote(split: str): @pytest.mark.parametrize('use_multiple_streams', [True, False]) @pytest.mark.parametrize(('backwards_compatibility_mode', 'use_bytes'), [[False, False], [True, False], [True, True]]) -def test_finetuning_dataloader_streaming(pretokenize: bool, - use_multiple_streams: bool, - backwards_compatibility_mode: bool, - use_bytes: bool, - tmp_path: pathlib.Path): +def test_finetuning_dataloader_streaming( + pretokenize: bool, + use_multiple_streams: bool, + backwards_compatibility_mode: bool, + use_bytes: bool, + tmp_path: pathlib.Path, +): clean_stale_shared_memory() max_seq_len = 2048 @@ -602,11 +667,12 @@ def test_finetuning_dataloader_streaming(pretokenize: bool, pretokenize, backwards_compatibility_mode=backwards_compatibility_mode, use_bytes=use_bytes, - tokenizer=tokenizer) + tokenizer=tokenizer, + ) streams_config['streams'][f'stream_{i}'] = { 'remote': remote_path, 'local': local_path, - 'split': 'train' + 'split': 'train', } cfg = { @@ -623,7 +689,7 @@ def test_finetuning_dataloader_streaming(pretokenize: bool, 'pin_memory': False, 'prefetch_factor': 2, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } if use_multiple_streams: cfg['dataset'].update(streams_config) @@ -653,76 +719,111 @@ def test_finetuning_dataloader_is_valid_ift_example( if not decoder_only_format: if (target_prompts != 'none') or (target_responses != 'last'): pytest.xfail( - 'Must use "none" and "last" for target prompts and responses if not using decoder_only_format' + 'Must use "none" and "last" for target prompts and responses if not using decoder_only_format', ) # This should pass - validate_target_settings(target_prompts, target_responses, - decoder_only_format) + validate_target_settings( + target_prompts, + target_responses, + decoder_only_format, + ) max_seq_len = 4 valid_example = {'turns': [{'input_ids': [2, 3, 5], 'labels': [8, 9, 7]}]} - assert is_valid_ift_example(max_seq_len, target_prompts, target_responses, - decoder_only_format, valid_example) + assert is_valid_ift_example( + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + valid_example, + ) maybe_too_long_example = { 'turns': [{ 'input_ids': [2, 3, 5], - 'labels': [8, 9, 7] - }] * 3 + 'labels': [8, 9, 7], + }] * 3, } if any([ - target_responses == 'all', - target_prompts in {'all', 'length>=2'}, - decoder_only_format == False, + target_responses == 'all', + target_prompts in {'all', 'length>=2'}, + decoder_only_format == False, ]): - assert is_valid_ift_example(max_seq_len, target_prompts, - target_responses, decoder_only_format, - maybe_too_long_example) + assert is_valid_ift_example( + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + maybe_too_long_example, + ) else: - assert not is_valid_ift_example(max_seq_len, target_prompts, - target_responses, decoder_only_format, - maybe_too_long_example) + assert not is_valid_ift_example( + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + maybe_too_long_example, + ) another_maybe_too_long_example = { 'turns': [{ 'input_ids': [2, 3, 5, 6, 8], - 'labels': [8, 9, 7] - }] + 'labels': [8, 9, 7], + }], } if any([ - target_prompts in {'all', 'length>=2'}, - decoder_only_format == False, + target_prompts in {'all', 'length>=2'}, + decoder_only_format == False, ]): - assert is_valid_ift_example(max_seq_len, target_prompts, - target_responses, decoder_only_format, - another_maybe_too_long_example) + assert is_valid_ift_example( + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + another_maybe_too_long_example, + ) else: - assert not is_valid_ift_example(max_seq_len, target_prompts, - target_responses, decoder_only_format, - another_maybe_too_long_example) + assert not is_valid_ift_example( + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + another_maybe_too_long_example, + ) empty_input_example = {'turns': [{'input_ids': [], 'labels': [8, 9, 7]}]} - assert not is_valid_ift_example(max_seq_len, target_prompts, - target_responses, decoder_only_format, - empty_input_example) + assert not is_valid_ift_example( + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + empty_input_example, + ) empty_labels_example = {'turns': [{'input_ids': [1, 2], 'labels': []}]} - assert not is_valid_ift_example(max_seq_len, target_prompts, - target_responses, decoder_only_format, - empty_labels_example) + assert not is_valid_ift_example( + max_seq_len, + target_prompts, + target_responses, + decoder_only_format, + empty_labels_example, + ) invalid_prompt_response_params = [ - 'add_bad_data_dropped', 'add_invalid_prompt_type', - 'add_invalid_response_type', 'add_unknown_example_type', - 'add_too_many_example_keys' + 'add_bad_data_dropped', + 'add_invalid_prompt_type', + 'add_invalid_response_type', + 'add_unknown_example_type', + 'add_too_many_example_keys', ] @pytest.mark.parametrize( ','.join(invalid_prompt_response_params), - generate_exclusive_test_params(invalid_prompt_response_params)) + generate_exclusive_test_params(invalid_prompt_response_params), +) def test_malformed_data( add_bad_data_dropped: bool, add_invalid_prompt_type: bool, @@ -779,28 +880,39 @@ def test_malformed_data( 'prefetch_factor': None, 'pin_memory': False, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) error_context = contextlib.nullcontext() if add_invalid_prompt_type: - error_context = pytest.raises(InvalidPromptTypeError, - match='Expected prompt to be') + error_context = pytest.raises( + InvalidPromptTypeError, + match='Expected prompt to be', + ) if add_invalid_response_type: - error_context = pytest.raises(InvalidResponseTypeError, - match='Expected response to be') + error_context = pytest.raises( + InvalidResponseTypeError, + match='Expected response to be', + ) if add_unknown_example_type: - error_context = pytest.raises(UnknownExampleTypeError, - match=r'.*Unknown example type') + error_context = pytest.raises( + UnknownExampleTypeError, + match=r'.*Unknown example type', + ) if add_too_many_example_keys: - error_context = pytest.raises(UnknownExampleTypeError, - match=r'.*Unknown example type') + error_context = pytest.raises( + UnknownExampleTypeError, + match=r'.*Unknown example type', + ) with error_context: - dl = build_finetuning_dataloader(cfg, tokenizer, - device_batch_size).dataloader + dl = build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ).dataloader if not any(invalid_prompt_response_params): # +5 because we added samples with just bos/eos in each of prompt/response @@ -814,20 +926,26 @@ def test_malformed_data( invalid_conversation_params = [ - 'add_invalid_last_chat_message', 'add_invalid_message_key_quantity', - 'add_invalid_content_type', 'add_invalid_role', 'add_not_alternating_roles' + 'add_invalid_last_chat_message', + 'add_invalid_message_key_quantity', + 'add_invalid_content_type', + 'add_invalid_role', + 'add_not_alternating_roles', ] @pytest.mark.parametrize( ','.join(invalid_conversation_params), - generate_exclusive_test_params(invalid_conversation_params)) -def test_malformed_conversation_data(tmp_path: pathlib.Path, - add_invalid_last_chat_message: bool, - add_invalid_message_key_quantity: bool, - add_invalid_content_type: bool, - add_invalid_role: bool, - add_not_alternating_roles: bool): + generate_exclusive_test_params(invalid_conversation_params), +) +def test_malformed_conversation_data( + tmp_path: pathlib.Path, + add_invalid_last_chat_message: bool, + add_invalid_message_key_quantity: bool, + add_invalid_content_type: bool, + add_invalid_role: bool, + add_not_alternating_roles: bool, +): tokenizer_name = 'mosaicml/mpt-7b' max_seq_len = 2048 dataset_size = 5 @@ -872,7 +990,7 @@ def test_malformed_conversation_data(tmp_path: pathlib.Path, 'prefetch_factor': None, 'pin_memory': False, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) @@ -882,25 +1000,38 @@ def test_malformed_conversation_data(tmp_path: pathlib.Path, error_context = contextlib.nullcontext() if add_invalid_last_chat_message: - error_context = pytest.raises(InvalidLastChatMessageRoleError, - match='Invalid last message role:') + error_context = pytest.raises( + InvalidLastChatMessageRoleError, + match='Invalid last message role:', + ) if add_invalid_message_key_quantity: - error_context = pytest.raises(IncorrectMessageKeyQuantityError, - match='Expected 2 keys in message') + error_context = pytest.raises( + IncorrectMessageKeyQuantityError, + match='Expected 2 keys in message', + ) if add_invalid_content_type: - error_context = pytest.raises(InvalidContentTypeError, - match='Expected content to be') + error_context = pytest.raises( + InvalidContentTypeError, + match='Expected content to be', + ) if add_invalid_role: - error_context = pytest.raises(InvalidRoleError, - match='Expected role to be one of') + error_context = pytest.raises( + InvalidRoleError, + match='Expected role to be one of', + ) if add_not_alternating_roles: - error_context = pytest.raises(ConsecutiveRepeatedChatRolesError, - match='Conversation roles must alternate') + error_context = pytest.raises( + ConsecutiveRepeatedChatRolesError, + match='Conversation roles must alternate', + ) with error_context: - build_finetuning_dataloader(cfg, tokenizer, - device_batch_size).dataloader + build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ).dataloader def test_finetune_dataloader_pure_pad_responses(): @@ -908,14 +1039,15 @@ def test_finetune_dataloader_pure_pad_responses(): @dataset_constructor.register('pad-response') def pad_preprocessing_function( # type: ignore - inp: dict[str, str]) -> dict[str, str]: + inp: dict[str, str], + ) -> dict[str, str]: """Split out prompt/response from text.""" try: prompt, response = inp['text'].split('### Response:') prompt += '### Response:' except Exception as e: raise ValueError( - f"Unable to extract prompt/response from 'text'={inp['text']}" + f"Unable to extract prompt/response from 'text'={inp['text']}", ) from e return {'prompt': prompt, 'response': '|PAD|' * len(response.split())} @@ -937,21 +1069,24 @@ def pad_preprocessing_function( # type: ignore 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, }) tokenizer_name = 'EleutherAI/gpt-neox-20b' tokenizer_kwargs = { 'model_max_length': cfg.dataset.max_seq_len, - 'pad_token': '|PAD|' + 'pad_token': '|PAD|', } tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) assert tokenizer('|PAD|').input_ids[0] == tokenizer.pad_token_id device_batch_size = 1 - dataloader = build_finetuning_dataloader(cfg, tokenizer, - device_batch_size).dataloader + dataloader = build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ).dataloader # We should be able to iterate through this dataset without crashing for i, batch in enumerate(dataloader): @@ -974,9 +1109,13 @@ def pad_preprocessing_function( # type: ignore @pytest.mark.parametrize('model_max_length', [1024, 2048]) @pytest.mark.parametrize('padding_side', ['left', 'right']) @pytest.mark.parametrize('add_decoder_input_ids', [True, False]) -def test_token_counting_func(pad_token_id: int, batch_size: int, - model_max_length: int, padding_side: str, - add_decoder_input_ids: bool): +def test_token_counting_func( + pad_token_id: int, + batch_size: int, + model_max_length: int, + padding_side: str, + add_decoder_input_ids: bool, +): gptt = transformers.AutoTokenizer.from_pretrained('gpt2') gptt.pad_token_id = pad_token_id gptt.model_max_length = model_max_length @@ -1000,29 +1139,38 @@ def test_token_counting_func(pad_token_id: int, batch_size: int, decoder_expected_token_count += sample_length expected_token_count += sample_length batch_tokenized['decoder_attention_mask'] = gptt( - decoder_batch_strings, padding=True, - return_tensors='pt')['attention_mask'] + decoder_batch_strings, + padding=True, + return_tensors='pt', + )['attention_mask'] token_counting_func = get_tokens_per_batch_func( - decoder_only=not add_decoder_input_ids) + decoder_only=not add_decoder_input_ids, + ) actual_token_count = token_counting_func(batch_tokenized) assert actual_token_count == expected_token_count -@pytest.mark.parametrize('dataloader_type,tensor_input', - [('finetuning-hf', False), - ('finetuning-streaming', False), ('text', True), - ('text', False)]) +@pytest.mark.parametrize( + 'dataloader_type,tensor_input', + [('finetuning-hf', False), ('finetuning-streaming', False), ('text', True), + ('text', False)], +) @pytest.mark.parametrize('pad_token_id', [100, None]) @pytest.mark.parametrize('batch_size', [1, 8]) @pytest.mark.parametrize('model_max_length', [1024]) @pytest.mark.parametrize('padding_side', ['left']) def test_token_counting_func_dataloader_setting( - dataloader_type: str, tensor_input: bool, pad_token_id: Optional[int], - batch_size: int, model_max_length: int, padding_side: str, - monkeypatch: pytest.MonkeyPatch): + dataloader_type: str, + tensor_input: bool, + pad_token_id: Optional[int], + batch_size: int, + model_max_length: int, + padding_side: str, + monkeypatch: pytest.MonkeyPatch, +): gptt = transformers.AutoTokenizer.from_pretrained('gpt2') gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id gptt.model_max_length = model_max_length @@ -1032,10 +1180,9 @@ def test_token_counting_func_dataloader_setting( expected_token_count = 0 for _ in range(batch_size): # Get randomly different lengths if we are going to add padding - sample_length = random.randint( - 1, model_max_length // - 4) if (pad_token_id is not None and - not tensor_input) else model_max_length // 4 + sample_length = random.randint(1, model_max_length // 4) if ( + pad_token_id is not None and not tensor_input + ) else model_max_length // 4 batch_strings.append(' '.join(['hello'] * sample_length)) expected_token_count += sample_length @@ -1062,7 +1209,7 @@ def test_token_counting_func_dataloader_setting( 'prefetch_factor': None, 'pin_memory': False, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } if dataloader_type == 'finetuning-hf': @@ -1077,11 +1224,13 @@ def test_token_counting_func_dataloader_setting( 'packing_ratio': None, 'shuffle': True, }, - **common_args + **common_args, }) monkeypatch.setattr( 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf', - lambda *args, **kwargs: []) + lambda *args, + **kwargs: [], + ) dl = build_finetuning_dataloader(cfg, gptt, batch_size) elif dataloader_type == 'finetuning-streaming': cfg = DictConfig({ @@ -1096,11 +1245,13 @@ def test_token_counting_func_dataloader_setting( 'packing_ratio': None, 'shuffle': True, }, - **common_args + **common_args, }) monkeypatch.setattr( 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_streaming', - lambda *args, **kwargs: []) + lambda *args, + **kwargs: [], + ) dl = build_finetuning_dataloader(cfg, gptt, batch_size) elif dataloader_type == 'text': cfg = DictConfig({ @@ -1113,12 +1264,15 @@ def test_token_counting_func_dataloader_setting( 'shuffle': True, 'shuffle_seed': 0, }, - **common_args + **common_args, }) ds_mock = MagicMock() ds_mock.tokenizer = gptt - monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', - lambda *args, **kwargs: ds_mock) + monkeypatch.setattr( + 'llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, + **kwargs: ds_mock, + ) dl = build_text_dataloader(cfg, gptt, batch_size) else: raise NotImplementedError() @@ -1141,18 +1295,24 @@ def test_build_unknown_dataloader(): invalid_conversation_params_sharegpt = [ - 'add_invalid_last_chat_message', 'add_invalid_content_type', - 'add_invalid_role', 'add_not_alternating_roles' + 'add_invalid_last_chat_message', + 'add_invalid_content_type', + 'add_invalid_role', + 'add_not_alternating_roles', ] @pytest.mark.parametrize( ','.join(invalid_conversation_params_sharegpt), - generate_exclusive_test_params(invalid_conversation_params_sharegpt)) -def test_sharegpt_format(tmp_path: pathlib.Path, - add_invalid_last_chat_message: bool, - add_invalid_content_type: bool, add_invalid_role: bool, - add_not_alternating_roles: bool): + generate_exclusive_test_params(invalid_conversation_params_sharegpt), +) +def test_sharegpt_format( + tmp_path: pathlib.Path, + add_invalid_last_chat_message: bool, + add_invalid_content_type: bool, + add_invalid_role: bool, + add_not_alternating_roles: bool, +): tokenizer_name = 'mosaicml/mpt-7b' max_seq_len = 2048 dataset_size = 5 @@ -1199,26 +1359,37 @@ def test_sharegpt_format(tmp_path: pathlib.Path, 'prefetch_factor': None, 'pin_memory': False, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, } cfg = om.create(cfg) error_context = contextlib.nullcontext() if add_invalid_last_chat_message: - error_context = pytest.raises(InvalidLastChatMessageRoleError, - match='Invalid last message role:') + error_context = pytest.raises( + InvalidLastChatMessageRoleError, + match='Invalid last message role:', + ) if add_invalid_content_type: - error_context = pytest.raises(InvalidContentTypeError, - match='Expected content to be') + error_context = pytest.raises( + InvalidContentTypeError, + match='Expected content to be', + ) if add_invalid_role: - error_context = pytest.raises(InvalidRoleError, - match='Expected role to be one of') + error_context = pytest.raises( + InvalidRoleError, + match='Expected role to be one of', + ) if add_not_alternating_roles: - error_context = pytest.raises(ConsecutiveRepeatedChatRolesError, - match='Conversation roles must alternate') + error_context = pytest.raises( + ConsecutiveRepeatedChatRolesError, + match='Conversation roles must alternate', + ) with error_context: - build_finetuning_dataloader(cfg, tokenizer, - device_batch_size).dataloader + build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size, + ).dataloader diff --git a/tests/data/test_icl_datasets.py b/tests/data/test_icl_datasets.py index 3a730fdf19..ce9fa7a493 100644 --- a/tests/data/test_icl_datasets.py +++ b/tests/data/test_icl_datasets.py @@ -16,15 +16,19 @@ def load_icl_config(conf_path: str = 'tests/data/test_tasks.yaml'): return test_cfg -def run_test(dir: pathlib.Path, - tokenizer: PreTrainedTokenizerBase, - bos_tok: str = ''): +def run_test( + dir: pathlib.Path, + tokenizer: PreTrainedTokenizerBase, + bos_tok: str = '', +): task_cfg = load_icl_config() - evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, - tokenizer, - 1024, - 8, - destination_dir=str(dir)) + evaluators, _ = build_icl_evaluators( + task_cfg.icl_tasks, + tokenizer, + 1024, + 8, + destination_dir=str(dir), + ) for e in evaluators: batch = next(e.dataloader.dataloader.__iter__()) @@ -34,14 +38,15 @@ def run_test(dir: pathlib.Path, continuation_indices = list(batch['continuation_indices'][0]) full_example = tokenizer.decode(inputs[0:continuation_indices[-1]]) answer = tokenizer.decode( - inputs[continuation_indices[0]:continuation_indices[-1]]) + inputs[continuation_indices[0]:continuation_indices[-1]], + ) else: if tokenizer.pad_token_id is not None: - start_idx = ( - inputs == tokenizer.pad_token_id).tolist().index(False) + start_idx = (inputs == tokenizer.pad_token_id + ).tolist().index(False) else: - start_idx = ( - inputs == tokenizer.eos_token_id).tolist().index(False) + start_idx = (inputs == tokenizer.eos_token_id + ).tolist().index(False) full_example = tokenizer.decode(inputs[start_idx:]) answer = batch['labels'][0][0] @@ -71,10 +76,14 @@ def run_test(dir: pathlib.Path, assert answer == ' feared violence' -@pytest.mark.parametrize('tokenizer_name,bos_token', - [('facebook/opt-6.7b', ''), - ('EleutherAI/gpt-neox-20b', '')]) -def test_icl_task_tokenizer(tmp_path: pathlib.Path, tokenizer_name: str, - bos_token: str): +@pytest.mark.parametrize( + 'tokenizer_name,bos_token', + [('facebook/opt-6.7b', ''), ('EleutherAI/gpt-neox-20b', '')], +) +def test_icl_task_tokenizer( + tmp_path: pathlib.Path, + tokenizer_name: str, + bos_token: str, +): tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) run_test(tmp_path, tokenizer, bos_token) diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index 963f8e56b6..db95bcad3d 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -36,11 +36,13 @@ def test_packing(): """Tests that packing works for a single batch.""" pad_token_id = 0 max_seq_len = 5 - packer = BinPackCollator(collator=lambda x: x, - target_batch_size=2, - max_seq_len=max_seq_len, - pad_token_id=pad_token_id, - padding_side='right') + packer = BinPackCollator( + collator=lambda x: x, + target_batch_size=2, + max_seq_len=max_seq_len, + pad_token_id=pad_token_id, + padding_side='right', + ) batch = _data_to_batch([ [1], @@ -51,8 +53,10 @@ def test_packing(): packed_samples = packer.pack(batch) - assert torch.equal(packed_samples['input_ids'], - torch.Tensor([[3, 3, 3, 2, 2], [4, 4, 4, 4, 1]])) + assert torch.equal( + packed_samples['input_ids'], + torch.Tensor([[3, 3, 3, 2, 2], [4, 4, 4, 4, 1]]), + ) assert torch.all(packed_samples['attention_mask'] == 1) @@ -60,11 +64,13 @@ def test_packing_with_leftovers(): """Tests that packing handles leftovers and computes waste correctly.""" pad_token_id = 0 max_seq_len = 5 - packer = BinPackCollator(collator=lambda x: x, - target_batch_size=2, - max_seq_len=max_seq_len, - pad_token_id=pad_token_id, - padding_side='right') + packer = BinPackCollator( + collator=lambda x: x, + target_batch_size=2, + max_seq_len=max_seq_len, + pad_token_id=pad_token_id, + padding_side='right', + ) batch = _data_to_batch([ [1], @@ -75,10 +81,14 @@ def test_packing_with_leftovers(): packed_batch = packer.pack(batch) - assert torch.equal(packed_batch['input_ids'], - torch.Tensor([[4, 4, 4, 4, 1], [4, 4, 4, 4, 0]])) - assert torch.equal(packed_batch['attention_mask'], - torch.Tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])) + assert torch.equal( + packed_batch['input_ids'], + torch.Tensor([[4, 4, 4, 4, 1], [4, 4, 4, 4, 0]]), + ) + assert torch.equal( + packed_batch['attention_mask'], + torch.Tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]), + ) # Check leftovers and waste. assert len(packer._leftover_bins) == 1 @@ -91,10 +101,14 @@ def test_packing_with_leftovers(): # Ensure that leftovers are used in the next batch if possible. batch = _data_to_batch([[1]], max_seq_len, pad_token_id) packed_batch = packer.pack(batch) - assert torch.equal(packed_batch['input_ids'], - torch.Tensor([[2, 2, 0, 0, 0], [1, 0, 0, 0, 0]])) - assert torch.equal(packed_batch['attention_mask'], - torch.Tensor([[1, 1, 0, 0, 0], [1, 0, 0, 0, 0]])) + assert torch.equal( + packed_batch['input_ids'], + torch.Tensor([[2, 2, 0, 0, 0], [1, 0, 0, 0, 0]]), + ) + assert torch.equal( + packed_batch['attention_mask'], + torch.Tensor([[1, 1, 0, 0, 0], [1, 0, 0, 0, 0]]), + ) @patch('llmfoundry.data.packing.profile_packing') @@ -108,7 +122,7 @@ def test_auto_packing(profile_packing: Mock): packing_ratio = auto_packing_ratio( dataloader_cfg=DictConfig({'dataset': { - 'max_seq_len': 2048 + 'max_seq_len': 2048, }}), tokenizer=None, device_batch_size=1, @@ -135,7 +149,7 @@ def test_dist_auto_packing(profile_packing: Mock): packing_ratio = auto_packing_ratio( dataloader_cfg=DictConfig({'dataset': { - 'max_seq_len': 2048 + 'max_seq_len': 2048, }}), tokenizer=None, device_batch_size=1, @@ -151,8 +165,10 @@ def patched_packing_ratio(*args: Any, **kwargs: Any): return auto_packing_ratio(*args, **kwargs, num_packing_ratios=4) -@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio', - patched_packing_ratio) +@patch( + 'llmfoundry.data.finetuning.dataloader.auto_packing_ratio', + patched_packing_ratio, +) def test_auto_packing_with_streaming_dataloader(tmp_path: Path): columns = {'prompt': 'str', 'response': 'str'} tokenizer = build_tokenizer('gpt2', {}) @@ -167,7 +183,7 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): 'local': local_dir, 'packing_ratio': 'auto', 'max_seq_len': 200, - 'decoder_only_format': True + 'decoder_only_format': True, }, 'drop_last': False, # Need to test with 0 num_workers because the packing collator object @@ -179,8 +195,11 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): 'timeout': 0, }) - loader = build_finetuning_dataloader(cfg, tokenizer, - device_batch_size=6).dataloader + loader = build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size=6, + ).dataloader batch_ix = 0 for _ in loader: @@ -190,8 +209,10 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): @pytest.mark.parametrize('packing_ratio', ['auto', 2.0]) -@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio', - patched_packing_ratio) +@patch( + 'llmfoundry.data.finetuning.dataloader.auto_packing_ratio', + patched_packing_ratio, +) def test_packing_with_dataloader(packing_ratio: Any): """Tests that packing works with a dataloader.""" reproducibility.seed_all(17) @@ -217,8 +238,11 @@ def test_packing_with_dataloader(packing_ratio: Any): 'timeout': 0, }) - loader = build_finetuning_dataloader(cfg, tokenizer, - device_batch_size=6).dataloader + loader = build_finetuning_dataloader( + cfg, + tokenizer, + device_batch_size=6, + ).dataloader assert isinstance(loader, DataLoader) pack_collator = loader.collate_fn diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 79f17b4ace..785a124379 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -6,12 +6,16 @@ import pytest import transformers -from llmfoundry.data.finetuning.tasks import (_slice_chat_formatted_example, - dataset_constructor, - tokenize_formatted_example) +from llmfoundry.data.finetuning.tasks import ( + _slice_chat_formatted_example, + dataset_constructor, + tokenize_formatted_example, +) from llmfoundry.utils.builders import build_tokenizer -from llmfoundry.utils.exceptions import (ALLOWED_PROMPT_KEYS, - ALLOWED_RESPONSE_KEYS) +from llmfoundry.utils.exceptions import ( + ALLOWED_PROMPT_KEYS, + ALLOWED_RESPONSE_KEYS, +) def test_tokenize_chat_example_malformed(): @@ -19,40 +23,44 @@ def test_tokenize_chat_example_malformed(): too_few_messages = { 'messages': [{ 'role': 'assistant', - 'content': 'Hi, User!' - }] + 'content': 'Hi, User!', + }], } ends_with_user_role = { 'messages': [{ 'role': 'user', - 'content': 'Hello GPT!' + 'content': 'Hello GPT!', }, { 'role': 'assistant', - 'content': 'Hi, User!' + 'content': 'Hi, User!', }, { 'role': 'user', - 'content': 'user message not followed by an assistant label' - }] + 'content': 'user message not followed by an assistant label', + }], } no_assistant_message = { 'messages': [{ 'role': 'user', - 'content': 'Hello GPT!' + 'content': 'Hello GPT!', }, { 'role': 'user', - 'content': 'user message not followed by an assistant label' - }] + 'content': 'user message not followed by an assistant label', + }], } wrong_type = {'messages': 'this is not a list of messages'} malformed_chat_examples = [ - too_few_messages, no_content, ends_with_user_role, no_assistant_message, - wrong_type + too_few_messages, + no_content, + ends_with_user_role, + no_assistant_message, + wrong_type, ] my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) for example in malformed_chat_examples: with pytest.raises(Exception): tokenize_formatted_example( - example, my_tokenizer + example, + my_tokenizer, ) # type: ignore (the typing here is supposed to be malformed) @@ -61,31 +69,31 @@ def test_tokenize_chat_example_well_formed(): { 'messages': [{ 'role': 'user', - 'content': 'Hello, GPT' + 'content': 'Hello, GPT', }, { 'role': 'assistant', - 'content': 'this is my response' - }] + 'content': 'this is my response', + }], }, # prompt/response but in chat format { 'messages': [ { 'role': 'user', - 'content': 'Hello, GPT' + 'content': 'Hello, GPT', }, { 'role': 'assistant', - 'content': 'this is my response' + 'content': 'this is my response', }, { 'role': 'user', - 'content': 'Nice to hear that.' + 'content': 'Nice to hear that.', }, { 'role': 'assistant', - 'content': 'multi-way chat works too!' + 'content': 'multi-way chat works too!', }, - ] + ], }, # multi-way chat ] @@ -99,7 +107,7 @@ def test_tokenize_chat_example_well_formed(): <|im_start|>assistant ''', 'response': - 'this is my response<|im_end|>' + 'this is my response<|im_end|>', }], [{ 'prompt': @@ -110,7 +118,7 @@ def test_tokenize_chat_example_well_formed(): <|im_start|>assistant ''', 'response': - 'this is my response<|im_end|>' + 'this is my response<|im_end|>', }, { 'prompt': ''' @@ -119,22 +127,27 @@ def test_tokenize_chat_example_well_formed(): <|im_start|>assistant ''', 'response': - 'multi-way chat works too!<|im_end|>' + 'multi-way chat works too!<|im_end|>', }], ] chat_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) assert len(expected) == len( - chat_examples) # if we add a new example, zip shouldn't fail silently + chat_examples, + ) # if we add a new example, zip shouldn't fail silently for chat_example, expected_stringification in zip(chat_examples, expected): templatized_prompt_response_turns = _slice_chat_formatted_example( - chat_example, chat_tokenizer) - tokenized_example = tokenize_formatted_example(chat_example, - chat_tokenizer) + chat_example, + chat_tokenizer, + ) + tokenized_example = tokenize_formatted_example( + chat_example, + chat_tokenizer, + ) for (prompt, response), exp_str, turn in zip( - templatized_prompt_response_turns, - expected_stringification, - tokenized_example['turns'], + templatized_prompt_response_turns, + expected_stringification, + tokenized_example['turns'], ): assert prompt == exp_str['prompt'] assert response == exp_str['response'] @@ -151,12 +164,16 @@ def test_tokenize_instruct_example_malformed(): multiple_allowed_response_keys = { 'prompt': 'prompt', 'response': 'response', - 'completion': 'completion' + 'completion': 'completion', } malformed_prompt_response_examples = [ - no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt, - extra_keys_with_response, multiple_allowed_response_keys + no_keys, + no_prompt_key, + no_response_key, + extra_keys_with_prompt, + extra_keys_with_response, + multiple_allowed_response_keys, ] for example in malformed_prompt_response_examples: @@ -178,58 +195,60 @@ def test_tokenize_instruct_example_well_formed(): @pytest.mark.parametrize( 'tokenizer_name', - ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base']) + ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'], +) @pytest.mark.parametrize('messages_format', [True, False]) def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): if messages_format: convo = [ { 'role': 'system', - 'content': 'everyone thinks you are so cool' + 'content': 'everyone thinks you are so cool', }, { 'role': 'user', - 'content': 'hiiii' + 'content': 'hiiii', }, { 'role': 'assistant', - 'content': 'yassss' + 'content': 'yassss', }, { 'role': 'user', - 'content': 'HIIIIII!!!' + 'content': 'HIIIIII!!!', }, { 'role': 'assistant', - 'content': 'YASSSSSS' + 'content': 'YASSSSSS', }, ] else: convo = [ { 'from': 'system', - 'value': 'everyone thinks you are so cool' + 'value': 'everyone thinks you are so cool', }, { 'from': 'human', - 'value': 'hiiii' + 'value': 'hiiii', }, { 'from': 'gpt', - 'value': 'yassss' + 'value': 'yassss', }, { 'from': 'tool', - 'value': 'HIIIIII!!!' + 'value': 'HIIIIII!!!', }, { 'from': 'gpt', - 'value': 'YASSSSSS' + 'value': 'YASSSSSS', }, ] tmp = {'conversations': convo} preprocessor = dataset_constructor.get_preprocessing_fn_from_str( - 'teknium/OpenHermes-2.5') + 'teknium/OpenHermes-2.5', + ) assert preprocessor is not None convo = preprocessor(tmp)['messages'] assert isinstance(convo, list) @@ -239,7 +258,9 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name) templated_prompt_response_turns = _slice_chat_formatted_example( - example, tok) + example, + tok, + ) reconstructed_chat = '' for prompt, response in templated_prompt_response_turns: @@ -252,7 +273,9 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): def test_tokenize_no_labels_bos_pr(): # This tokenizer automatically adds bos tokens tokenizer = transformers.AutoTokenizer.from_pretrained( - 'ai21labs/Jamba-v0.1', add_bos_token=True) + 'ai21labs/Jamba-v0.1', + add_bos_token=True, + ) example = {'prompt': 'prompt', 'response': 'response'} diff --git a/tests/data_utils.py b/tests/data_utils.py index fd24d4cbbf..30f49efb2d 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -36,7 +36,7 @@ def make_tiny_ft_dataset( if add_bad_data_dropped: if pad_token is None: raise ValueError( - 'pad_token, start_token, and end_token must be specified if add_bad_data is True' + 'pad_token, start_token, and end_token must be specified if add_bad_data is True', ) # empty prompt samples.append({'prompt': '', 'response': 'goodbye'}) @@ -47,14 +47,14 @@ def make_tiny_ft_dataset( # prompt just None samples.append({ 'prompt': None, - 'response': 'goodbye' + 'response': 'goodbye', }) # type: ignore (intentional test) if add_invalid_response_type: # response just None samples.append({ 'prompt': 'hello', - 'response': None + 'response': None, }) # type: ignore (intentional test) if add_too_many_example_keys: @@ -62,13 +62,13 @@ def make_tiny_ft_dataset( samples.append({ 'prompt': 'hello', 'response': 'goodbye', - 'completion': 'bar' + 'completion': 'bar', }) if add_just_bos_eos_pad: if pad_token is None or start_token is None or end_token is None: raise ValueError( - 'pad_token, start_token, and end_token must be specified if add_just_bos_eos is True' + 'pad_token, start_token, and end_token must be specified if add_just_bos_eos is True', ) # prompt just start samples.append({'prompt': start_token, 'response': 'goodbye'}) @@ -106,14 +106,14 @@ def make_tiny_conversation_ft_dataset( good_sample = { 'messages': [{ 'role': 'system', - 'content': 'A conversation between a user and a helpful assistant.' + 'content': 'A conversation between a user and a helpful assistant.', }, { 'role': 'user', - 'content': "Hi there. What's the capital of the moon?" + 'content': "Hi there. What's the capital of the moon?", }, { 'role': 'assistant', - 'content': "This question doesn't make sense." - }] + 'content': "This question doesn't make sense.", + }], } samples = [good_sample] * size @@ -125,14 +125,14 @@ def make_tiny_conversation_ft_dataset( 'role': 'system', 'content': - 'A conversation between a user and a helpful assistant.' + 'A conversation between a user and a helpful assistant.', }, { 'role': 'user', - 'content': "Hi there. What's the capital of the moon?" + 'content': "Hi there. What's the capital of the moon?", }, { 'role': 'system', - 'content': "This question doesn't make sense." - }] + 'content': "This question doesn't make sense.", + }], }) if add_invalid_message_key_quantity: @@ -144,8 +144,8 @@ def make_tiny_conversation_ft_dataset( 'content': 'A conversation between a user and a helpful assistant.', 'extra_key': - 'extra value' - }] + 'extra value', + }], }) if add_invalid_role: @@ -155,14 +155,14 @@ def make_tiny_conversation_ft_dataset( 'role': 'system', 'content': - 'A conversation between a user and a helpful assistant.' + 'A conversation between a user and a helpful assistant.', }, { 'role': 'foo', - 'content': "Hi there. What's the capital of the moon?" + 'content': "Hi there. What's the capital of the moon?", }, { 'role': 'assistant', - 'content': "This question doesn't make sense." - }] + 'content': "This question doesn't make sense.", + }], }) if add_invalid_content_type: @@ -172,14 +172,14 @@ def make_tiny_conversation_ft_dataset( 'role': 'system', 'content': - 'A conversation between a user and a helpful assistant.' + 'A conversation between a user and a helpful assistant.', }, { 'role': 'user', - 'content': "Hi there. What's the capital of the moon?" + 'content': "Hi there. What's the capital of the moon?", }, { 'role': 'assistant', - 'content': None - }] + 'content': None, + }], }) # type: ignore (intentional test) if add_not_alternating_roles: @@ -189,14 +189,14 @@ def make_tiny_conversation_ft_dataset( 'role': 'system', 'content': - 'A conversation between a user and a helpful assistant.' + 'A conversation between a user and a helpful assistant.', }, { 'role': 'assistant', - 'content': "Hi there. What's the capital of the moon?" + 'content': "Hi there. What's the capital of the moon?", }, { 'role': 'assistant', - 'content': "This question doesn't make sense." - }] + 'content': "This question doesn't make sense.", + }], }) def messages_to_conversation(sample: Dict): @@ -244,14 +244,18 @@ def create_c4_dataset_xxsmall(path: Path) -> str: 'bos_text': '', 'eos_text': '<|endoftext|>', 'no_wrap': False, - 'num_workers': 8 - })) + 'num_workers': 8, + }, + ), + ) # copy the small downloaded_split to other c4 splits for mocking purposes mocked_splits = ['train', 'val'] for mocked_split in mocked_splits: - shutil.copytree(os.path.join(c4_dir, 'val_xxsmall'), - os.path.join(c4_dir, mocked_split)) + shutil.copytree( + os.path.join(c4_dir, 'val_xxsmall'), + os.path.join(c4_dir, mocked_split), + ) assert os.path.exists(c4_dir) return c4_dir @@ -276,8 +280,10 @@ def create_arxiv_dataset(path: Path) -> str: 'bos_text': None, 'eos_text': None, 'no_wrap': False, - 'num_workers': None - })) + 'num_workers': None, + }, + ), + ) return arxiv_dir diff --git a/tests/eval/test_in_context_learning_datasets.py b/tests/eval/test_in_context_learning_datasets.py index 8ed06b6f70..ea87ed17d0 100644 --- a/tests/eval/test_in_context_learning_datasets.py +++ b/tests/eval/test_in_context_learning_datasets.py @@ -10,37 +10,43 @@ import pytest import torch +import transformers from composer import Evaluator from composer.core import DataSpec -from torch.utils.data import DataLoader - -# isort: off -from llmfoundry.eval.datasets import ( - InContextLearningDataset, InContextLearningCodeEvalDataset, - InContextLearningMultipleChoiceTaskDataset, - InContextLearningGenerationTaskWithAnswersDataset, - InContextLearningSchemaTaskDataset, get_icl_task_dataloader, strip_data, - tokenizer_needs_prefix_space, trim_context, get_continuation_span, - get_fewshot_sample_idxs, make_padded_input) -# isort: on -import transformers from composer.datasets.utils import MultiTokenEOSCriteria from composer.loggers import InMemoryLogger from composer.models import HuggingFaceModel from composer.trainer import Trainer from composer.utils import dist, reproducibility +from torch.utils.data import DataLoader +from llmfoundry.eval.datasets import ( + InContextLearningCodeEvalDataset, + InContextLearningDataset, + InContextLearningGenerationTaskWithAnswersDataset, + InContextLearningMultipleChoiceTaskDataset, + InContextLearningSchemaTaskDataset, + get_continuation_span, + get_fewshot_sample_idxs, + get_icl_task_dataloader, + make_padded_input, + strip_data, + tokenizer_needs_prefix_space, + trim_context, +) from llmfoundry.eval.metrics import ( InContextLearningCodeEvalAccuracy, - InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, - InContextLearningMultipleChoiceAccuracy) + InContextLearningGenerationExactMatchAccuracy, + InContextLearningLMAccuracy, + InContextLearningMultipleChoiceAccuracy, +) def test_strip_data(): data_to_strip = { 'strip_data': ' boo! \n', 'has_space': ' wa hoo!', - 'end_space': 'yoohoo! ' + 'end_space': 'yoohoo! ', } stripped_data = strip_data(data_to_strip) for k, v in stripped_data.items(): @@ -50,16 +56,19 @@ def test_strip_data(): @pytest.mark.skip( - reason="Currently don't have a tokenizer that satisfies this test") + reason="Currently don't have a tokenizer that satisfies this test", +) def test_tokenizer_needs_prefix_space_when_space_not_needed( - tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): assert not tokenizer_needs_prefix_space(tiny_gpt2_tokenizer) def test_tokenizer_needs_prefix_space_when_space_needed(): tokenizer = transformers.AutoTokenizer.from_pretrained( 'facebook/opt-125m', - use_fast=False) # type: ignore reportUnboundVariable + use_fast=False, + ) # type: ignore reportUnboundVariable assert tokenizer_needs_prefix_space(tokenizer) @@ -67,9 +76,11 @@ def test_trim_context(): context = [0] * 99 + [1] * 2037 continuation = [2] * 10 max_seq_len = 2048 - trimmed_context = trim_context(context, - continuation, - max_seq_len=max_seq_len) + trimmed_context = trim_context( + context, + continuation, + max_seq_len=max_seq_len, + ) assert len(trimmed_context) == 2038 assert trimmed_context[0] == 0 assert trimmed_context[1] == 1 @@ -98,20 +109,26 @@ def test_get_continuation_span(): @pytest.mark.parametrize('padding_side', ['left', 'right', 'middle']) -def test_make_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, - padding_side: str): +def test_make_padding( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + padding_side: str, +): context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] padding_id = tiny_gpt2_tokenizer.eos_token_id error_context = contextlib.nullcontext() if padding_side in { - 'left', 'right' + 'left', + 'right', } else pytest.raises(ValueError) with error_context: - input_ids = make_padded_input(context, [], - 2048, - padding_id, - padding_side=padding_side) + input_ids = make_padded_input( + context, + [], + 2048, + padding_id, + padding_side=padding_side, + ) if padding_side == 'left': assert input_ids[0] == tiny_gpt2_tokenizer.eos_token_id @@ -122,34 +139,40 @@ def test_make_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, def test_batch_padding_logic_no_padding( - tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids'] context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids'] max_seq_len = 2048 trimmed_context = trim_context(context, continuation, max_seq_len) continuation_spans = get_continuation_span(trimmed_context, continuation) - padded_input = make_padded_input(trimmed_context, - continuation, - max_seq_len, - tiny_gpt2_tokenizer.pad_token_id, - padding_side='right') + padded_input = make_padded_input( + trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right', + ) assert continuation_spans[0] == 48 and continuation_spans[-1] == 2047 assert len(padded_input) == 2048 assert tiny_gpt2_tokenizer.pad_token_id not in padded_input def test_batch_padding_logic_with_padding( - tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): continuation = tiny_gpt2_tokenizer(' dog' * 200)['input_ids'] context = tiny_gpt2_tokenizer(' cat' * 200)['input_ids'] max_seq_len = 2048 trimmed_context = trim_context(context, continuation, max_seq_len) continuation_spans = get_continuation_span(trimmed_context, continuation) - padded_input = make_padded_input(trimmed_context, - continuation, - max_seq_len, - tiny_gpt2_tokenizer.pad_token_id, - padding_side='right') + padded_input = make_padded_input( + trimmed_context, + continuation, + max_seq_len, + tiny_gpt2_tokenizer.pad_token_id, + padding_side='right', + ) assert continuation_spans[0] == 200 and continuation_spans[-1] == 399 assert len(padded_input) == 2048 assert padded_input[-1] == tiny_gpt2_tokenizer.pad_token_id @@ -158,28 +181,36 @@ def test_batch_padding_logic_with_padding( def test_fewshot_sample_idxs(): rng = random.Random(1234) - fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, - num_fewshot=4, - example_idx=4, - rng=rng) + fewshot_idxs = get_fewshot_sample_idxs( + dataset_size=5, + num_fewshot=4, + example_idx=4, + rng=rng, + ) assert fewshot_idxs == {0, 1, 2, 3} - fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, - num_fewshot=5, - example_idx=4, - rng=rng) + fewshot_idxs = get_fewshot_sample_idxs( + dataset_size=5, + num_fewshot=5, + example_idx=4, + rng=rng, + ) assert fewshot_idxs == {0, 1, 2, 3} - fewshot_idxs = get_fewshot_sample_idxs(dataset_size=5, - num_fewshot=500, - example_idx=4, - rng=rng) + fewshot_idxs = get_fewshot_sample_idxs( + dataset_size=5, + num_fewshot=500, + example_idx=4, + rng=rng, + ) assert fewshot_idxs == {0, 1, 2, 3} - fewshot_idxs = get_fewshot_sample_idxs(dataset_size=10, - num_fewshot=7, - example_idx=4, - rng=rng) + fewshot_idxs = get_fewshot_sample_idxs( + dataset_size=10, + num_fewshot=7, + example_idx=4, + rng=rng, + ) assert len(fewshot_idxs) == 7 and 4 not in fewshot_idxs @@ -191,32 +222,58 @@ def test_fewshot_sample_idxs_randomness(): rng_2_seed_1234 = random.Random(1234) rng_3_seed_11 = random.Random(11) - rng_1_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, - rng_1_seed_1234) - rng_2_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, - rng_2_seed_1234) - rng_3_sample_1 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 1, - rng_3_seed_11) + rng_1_sample_1 = get_fewshot_sample_idxs( + dataset_size, + num_fewshot, + 1, + rng_1_seed_1234, + ) + rng_2_sample_1 = get_fewshot_sample_idxs( + dataset_size, + num_fewshot, + 1, + rng_2_seed_1234, + ) + rng_3_sample_1 = get_fewshot_sample_idxs( + dataset_size, + num_fewshot, + 1, + rng_3_seed_11, + ) assert rng_1_sample_1 == rng_2_sample_1 assert rng_1_sample_1 != rng_3_sample_1 - rng_1_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, - rng_1_seed_1234) - rng_2_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, - rng_2_seed_1234) - rng_3_sample_2 = get_fewshot_sample_idxs(dataset_size, num_fewshot, 2, - rng_3_seed_11) + rng_1_sample_2 = get_fewshot_sample_idxs( + dataset_size, + num_fewshot, + 2, + rng_1_seed_1234, + ) + rng_2_sample_2 = get_fewshot_sample_idxs( + dataset_size, + num_fewshot, + 2, + rng_2_seed_1234, + ) + rng_3_sample_2 = get_fewshot_sample_idxs( + dataset_size, + num_fewshot, + 2, + rng_3_seed_11, + ) assert rng_1_sample_2 == rng_2_sample_2 assert rng_1_sample_2 != rng_3_sample_2 @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_update_generation_kwargs( - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): tokenizer = tiny_gpt2_tokenizer seqlen = 2048 num_fewshot = 0 @@ -242,58 +299,72 @@ def test_update_generation_kwargs( destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - generation_kwargs=gen_kwargs) + generation_kwargs=gen_kwargs, + ) assert dl.base_batch['generation_kwargs'] == { 'test_arg1': 1, - 'test_arg2': 2 + 'test_arg2': 2, } def test_stop_sequences_criteria( - tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids'] seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] seq1 = [tiny_gpt2_tokenizer.pad_token_id] * (len(seq2) - len(seq1)) + seq1 input_ids = torch.LongTensor([seq1, seq2]) - assert not eos_criteria(input_ids, - None) # pyright: ignore[reportGeneralTypeIssues] + assert not eos_criteria( + input_ids, + None, + ) # pyright: ignore[reportGeneralTypeIssues] eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2) seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids'] input_ids = torch.LongTensor([seq1, seq2]) - assert eos_criteria(input_ids, - None) # pyright: ignore[reportGeneralTypeIssues] + assert eos_criteria( + input_ids, + None, + ) # pyright: ignore[reportGeneralTypeIssues] def test_stop_sequences_criteria_sentencepiece( - tiny_llama_tokenizer: transformers.AutoTokenizer): + tiny_llama_tokenizer: transformers.AutoTokenizer, +): tokenizer = tiny_llama_tokenizer eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2) seq1 = tokenizer( - '\n\nDogs' - )['input_ids'] # check to make sure starting with the stop sequence doesnt break it + '\n\nDogs', + )['input_ids' + ] # check to make sure starting with the stop sequence doesnt break it seq2 = tokenizer('Dogs are furry\n\n')['input_ids'] seq1 = [tokenizer.eos_token_id] * (len(seq2) - len(seq1)) + seq1 input_ids = torch.LongTensor([seq1, seq2]) - assert not eos_criteria(input_ids, - None) # pyright: ignore[reportGeneralTypeIssues] + assert not eos_criteria( + input_ids, + None, + ) # pyright: ignore[reportGeneralTypeIssues] eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2) seq1 = tokenizer('Dogs are furry\n\n')['input_ids'] seq2 = tokenizer('Dogs are furry\n\n')['input_ids'] input_ids = torch.LongTensor([seq1, seq2]) - assert eos_criteria(input_ids, - None) # pyright: ignore[reportGeneralTypeIssues] + assert eos_criteria( + input_ids, + None, + ) # pyright: ignore[reportGeneralTypeIssues] @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_update_generation_kwargs_no_kwargs( - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): tokenizer = tiny_gpt2_tokenizer seqlen = 2048 num_fewshot = 0 @@ -317,7 +388,8 @@ def test_update_generation_kwargs_no_kwargs( continuation_delimiter='\nSpell:', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map) + hf_parsing_map=hf_parsing_map, + ) assert not 'generation_kwargs' in dl.base_batch @@ -326,7 +398,8 @@ def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path: Path): dataset_uri = f'{local_data}/triviaqa_small.jsonl' tokenizer = transformers.AutoTokenizer.from_pretrained( - 'facebook/opt-125m') # type: ignore reportUnboundVariable + 'facebook/opt-125m', + ) # type: ignore reportUnboundVariable tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) @@ -341,7 +414,8 @@ def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path: Path): example_delimiter='\n', continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), - generation_kwargs=None) + generation_kwargs=None, + ) assert len(dl.base_batch['generation_kwargs']) == 4 @@ -350,7 +424,8 @@ def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path: Path): dataset_uri = f'{local_data}/triviaqa_small.jsonl' tokenizer = transformers.AutoTokenizer.from_pretrained( - 'facebook/opt-125m') # type: ignore reportUnboundVariable + 'facebook/opt-125m', + ) # type: ignore reportUnboundVariable tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) @@ -365,17 +440,20 @@ def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path: Path): example_delimiter='\n', continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), - generation_kwargs={'temperature': 0.9}) + generation_kwargs={'temperature': 0.9}, + ) assert 'generation_kwargs' in dl.base_batch assert dl.base_batch['generation_kwargs']['temperature'] == 0.9 assert len(dl.base_batch['generation_kwargs']) == 5 @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) -def test_construct_context(tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_construct_context( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): tokenizer = tiny_gpt2_tokenizer seqlen = 2048 num_fewshot = 0 @@ -399,33 +477,37 @@ def test_construct_context(tiny_gpt2_tokenizer: transformers.AutoTokenizer, continuation_delimiter='\nSpell: ', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map) + hf_parsing_map=hf_parsing_map, + ) constructed_context = dl.construct_context({ 'context': 'quas quas exort', - 'answer': 'ice wall' + 'answer': 'ice wall', }) assert constructed_context == 'Orbs: quas quas exort\nSpell: ' - constructed_context = dl.construct_context( - { - 'context': 'quas quas exort', - 'answer': 'ice wall' - }, add_answer=True) + constructed_context = dl.construct_context({ + 'context': 'quas quas exort', + 'answer': 'ice wall', + }, + add_answer=True) assert constructed_context == 'Orbs: quas quas exort\nSpell: ice wall' constructed_context = dl.construct_context( { 'context': 'quas quas exort', - 'answer': 'ice wall' + 'answer': 'ice wall', }, preceding_text='The harsh White Waste beckons!', - add_answer=True) + add_answer=True, + ) assert constructed_context == '\nOrbs: quas quas exort\nSpell: ice wall' @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_get_answer_from_example( - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): tokenizer = tiny_gpt2_tokenizer seqlen = 2048 num_fewshot = 0 @@ -449,21 +531,23 @@ def test_get_answer_from_example( continuation_delimiter='\nSpell:', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map) + hf_parsing_map=hf_parsing_map, + ) answer = dl.get_answer_from_example({ 'context': 'wex exort exort', - 'answer': 'alacrity' + 'answer': 'alacrity', }) assert answer == ' alacrity' @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_fix_eos_on_preamble(tmp_path: Path): tokenizer = transformers.AutoTokenizer.from_pretrained( 'facebook/opt-125m', - use_fast=False) # type: ignore reportUnboundVariable + use_fast=False, + ) # type: ignore reportUnboundVariable seqlen = 2048 num_fewshot = 0 prompt_string = '' @@ -486,7 +570,8 @@ def test_fix_eos_on_preamble(tmp_path: Path): continuation_delimiter='\nSpell:', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map) + hf_parsing_map=hf_parsing_map, + ) preamble = 'blah blah blah.' tokenized_preamble = tokenizer.encode(preamble) tokenized_preamble += [tokenizer.eos_token_id] @@ -496,10 +581,12 @@ def test_fix_eos_on_preamble(tmp_path: Path): @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_tokenize_example_with_tokenize_labels( - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): tokenizer = tiny_gpt2_tokenizer seqlen = 2048 num_fewshot = 0 @@ -524,13 +611,32 @@ def test_tokenize_example_with_tokenize_labels( destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - tokenize_labels=True) - tokenized_example = dl.tokenize_example('What spell does this invoke? ', - 'exort exort wex\nSpell: ', - {'answer': ' Meatball'}) + tokenize_labels=True, + ) + tokenized_example = dl.tokenize_example( + 'What spell does this invoke? ', + 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}, + ) tokenized_input = [ - 2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, - 31221, 25, 19145, 1894 + 2061, + 4822, + 857, + 428, + 26342, + 30, + 220, + 1069, + 419, + 409, + 419, + 356, + 87, + 198, + 31221, + 25, + 19145, + 1894, ] assert tokenized_example['context'][:len(tokenized_input)].tolist( ) == tokenized_input @@ -541,10 +647,12 @@ def test_tokenize_example_with_tokenize_labels( @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_tokenize_example_with_no_tokenize_labels( - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): tokenizer = tiny_gpt2_tokenizer seqlen = 2048 num_fewshot = 0 @@ -569,13 +677,30 @@ def test_tokenize_example_with_no_tokenize_labels( destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - tokenize_labels=False) - tokenized_example = dl.tokenize_example('What spell does this invoke? ', - 'exort exort wex\nSpell: ', - {'answer': ' Meatball'}) + tokenize_labels=False, + ) + tokenized_example = dl.tokenize_example( + 'What spell does this invoke? ', + 'exort exort wex\nSpell: ', + {'answer': ' Meatball'}, + ) tokenized_input = [ - 2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, - 31221, 25 + 2061, + 4822, + 857, + 428, + 26342, + 30, + 220, + 1069, + 419, + 409, + 419, + 356, + 87, + 198, + 31221, + 25, ] assert tokenized_example['context'][:len(tokenized_input)].tolist( ) == tokenized_input @@ -589,7 +714,8 @@ def test_qa_set_cot_no_cot(tmp_path: Path): dataset_uri = f'{local_data}/triviaqa_small.jsonl' tokenizer = transformers.AutoTokenizer.from_pretrained( - 'facebook/opt-125m') # type: ignore reportUnboundVariable + 'facebook/opt-125m', + ) # type: ignore reportUnboundVariable tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) @@ -613,7 +739,8 @@ def test_qa_set_cot_has_cot(tmp_path: Path): dataset_uri = f'{local_data}/gsm8k_small.jsonl' tokenizer = transformers.AutoTokenizer.from_pretrained( - 'facebook/opt-125m') # type: ignore reportUnboundVariable + 'facebook/opt-125m', + ) # type: ignore reportUnboundVariable tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) @@ -633,7 +760,9 @@ def test_qa_set_cot_has_cot(tmp_path: Path): def test_qa_get_max_answer_length( - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/triviaqa_small.jsonl' tokenizer = tiny_gpt2_tokenizer @@ -658,7 +787,9 @@ def test_qa_get_max_answer_length( def test_qa_get_answer_from_example_with_no_cot( - tmp_path: Path, tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tmp_path: Path, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/triviaqa_small.jsonl' @@ -681,13 +812,15 @@ def test_qa_get_answer_from_example_with_no_cot( answer = dl.get_answer_from_example({ 'context': 'empty', 'answer': 'this is the correct answer', - 'chain_of_thought': "Let's think step by step. " + 'chain_of_thought': "Let's think step by step. ", }) assert answer == 'this is the correct answer' def test_qa_get_answer_from_example_with_cot( - tmp_path: Path, tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tmp_path: Path, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/triviaqa_small.jsonl' @@ -711,13 +844,15 @@ def test_qa_get_answer_from_example_with_cot( answer = dl.get_answer_from_example({ 'context': 'empty', 'answer': 'this is the correct answer', - 'chain_of_thought': "Let's think step by step. " + 'chain_of_thought': "Let's think step by step. ", }) assert answer == "Let's think step by step. ### this is the correct answer" -def test_qa_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_qa_tokenize_example( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/triviaqa_small.jsonl' @@ -739,20 +874,26 @@ def test_qa_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, ) dl.has_cot = True tokenized_example = dl.tokenize_example( - 'starting prompt', 'a context', { + 'starting prompt', + 'a context', + { 'context': 'empty', 'answer': 'this is the correct answer', 'aliases': ['this is the right answer', 'this is the best answer'], - 'chain_of_thought': "Let's think step by step. " - }) + 'chain_of_thought': "Let's think step by step. ", + }, + ) assert 'aliases' in tokenized_example assert tokenized_example['aliases'] == [ - 'this is the right answer', 'this is the best answer' + 'this is the right answer', + 'this is the best answer', ] -def test_code_adjust_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_code_adjust_padding( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/human_eval_small.jsonl' tokenizer = tiny_gpt2_tokenizer @@ -778,12 +919,14 @@ def test_code_adjust_padding(tiny_gpt2_tokenizer: transformers.AutoTokenizer, ) assert all( - len(data['prompt']) == 148 - for data in dl.dataset) # pyright: ignore [reportGeneralTypeIssues] + len(data['prompt']) == 148 for data in dl.dataset + ) # pyright: ignore [reportGeneralTypeIssues] -def test_code_update_gen_kwargs(tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_code_update_gen_kwargs( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/human_eval_small.jsonl' tokenizer = tiny_gpt2_tokenizer @@ -813,8 +956,10 @@ def test_code_update_gen_kwargs(tiny_gpt2_tokenizer: transformers.AutoTokenizer, assert dl.base_batch['generation_kwargs']['do_sample'] == True -def test_mc_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_mc_tokenize_example( + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/mmlu_small.jsonl' tokenizer = tiny_gpt2_tokenizer @@ -839,12 +984,13 @@ def test_mc_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, "Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: ", 'choices': ['A', 'B', 'C', 'D'], 'gold': - 2 + 2, } tokenized_example = dl.tokenize_example( prompt_and_fewshot='Answer the following: ', ctxt=example['context'], - example=example) + example=example, + ) unpadded_queries = [ context[context != tokenizer.eos_token_id] for context in tokenized_example['query'] @@ -856,15 +1002,17 @@ def test_mc_tokenize_example(tiny_gpt2_tokenizer: transformers.AutoTokenizer, "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: A", "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: B", "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: C", - "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: D" + "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: D", ] assert untokenized_inputs == correct_output @pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) def test_schema_construct_context( - prelimiter: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): + prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/winograd_small.jsonl' tokenizer = tiny_gpt2_tokenizer @@ -887,7 +1035,7 @@ def test_schema_construct_context( example = { 'context_options': ['cont one', 'cont two'], 'gold': 0, - 'continuation': 'this is a continuation' + 'continuation': 'this is a continuation', } constructed_context = dl.construct_context(example) assert constructed_context == f'{prelimiter}cont one ### this is a continuation' @@ -924,21 +1072,27 @@ def test_schema_construct_multiple_contexts( example = { 'context_options': [f'cont one', 'cont two'], 'gold': 0, - 'continuation': 'this is a continuation' + 'continuation': 'this is a continuation', } constructed_contexts = dl._construct_multiple_contexts(example) assert constructed_contexts == [ - f'{prelimiter}cont one', f'{prelimiter}cont two' + f'{prelimiter}cont one', + f'{prelimiter}cont two', ] constructed_contexts = dl._construct_multiple_contexts( - example, preceding_text='some text') + example, + preceding_text='some text', + ) assert constructed_contexts == [ - f'{prelimiter}\ncont one ###', f'{prelimiter}\ncont two ###' + f'{prelimiter}\ncont one ###', + f'{prelimiter}\ncont two ###', ] def test_schema_tokenize_example( - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/winograd_small.jsonl' tokenizer = tiny_gpt2_tokenizer @@ -956,21 +1110,24 @@ def test_schema_tokenize_example( prompt_string=prompt_string, # pyright: ignore example_delimiter='\n', # pyright: ignore continuation_delimiter=' ### ', - destination_path=str(tmp_path / - 'test_human_eval_small.jsonl'), # pyright: ignore + destination_path=str( + tmp_path / 'test_human_eval_small.jsonl', + ), # pyright: ignore ) example = { 'context_options': ['context one', 'context two'], 'gold': 0, - 'continuation': 'this is a continuation' + 'continuation': 'this is a continuation', } tokenized_example = dl.tokenize_example( prompt_and_fewshot='prompt ', context_options=example['context_options'], - example=example) + example=example, + ) assert all( tiny_gpt2_tokenizer.decode(cont) == ' this is a continuation' - for cont in tokenized_example['answer']) + for cont in tokenized_example['answer'] + ) unpadded_inputs = [ context[context != tokenizer.eos_token_id] for context in tokenized_example['context_options'] @@ -980,14 +1137,16 @@ def test_schema_tokenize_example( ] assert untokenized_inputs == [ 'prompt context one this is a continuation', - 'prompt context two this is a continuation' + 'prompt context two this is a continuation', ] @pytest.mark.parametrize('dataset_uri', ['mmlu_small.jsonl']) def test_mc_task_dataloader_subcategories( - dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1008,7 +1167,8 @@ def test_mc_task_dataloader_subcategories( example_delimiter='\n', continuation_delimiter='Answer: ', destination_path=str(tmp_path / 'icl.jsonl'), - has_categories=True) + has_categories=True, + ) assert isinstance(dls, dict) assert 'computer_security' in dls @@ -1022,7 +1182,8 @@ def test_mc_task_dataloader_subcategories( assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' min_idx = min(batch['continuation_indices'][0]).item() @@ -1034,8 +1195,10 @@ def test_mc_task_dataloader_subcategories( 'pubmed_sm.jsonl', ]) def test_lm_task_dataloader_extra_space( - dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1043,17 +1206,19 @@ def test_lm_task_dataloader_extra_space( dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 64 - dl = get_icl_task_dataloader('language_modeling', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=10, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=10, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -1064,22 +1229,26 @@ def test_lm_task_dataloader_extra_space( assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' min_idx = min(batch['continuation_indices'][0]).item() max_idx = max(batch['continuation_indices'][0]).item() assert ' ' not in tokenizer.decode(batch['input_ids'][0][0:max_idx + 1]) - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' yes' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' yes' @pytest.mark.parametrize('dataset_uri', [ 'lambada_small.jsonl', ]) -def test_lm_task_dataloader(dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_lm_task_dataloader( + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1087,17 +1256,19 @@ def test_lm_task_dataloader(dataset_uri: str, dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 64 - dl = get_icl_task_dataloader('language_modeling', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -1108,20 +1279,25 @@ def test_lm_task_dataloader(dataset_uri: str, assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' min_idx = min(batch['continuation_indices'][0]).item() max_idx = max(batch['continuation_indices'][0]).item() - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' glen' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' glen' @pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) @pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) -def test_schema_task_dataloader(dataset_uri: str, prelimiter: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_schema_task_dataloader( + dataset_uri: str, + prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1129,18 +1305,20 @@ def test_schema_task_dataloader(dataset_uri: str, prelimiter: str, dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 64 - dl = get_icl_task_dataloader('schema', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - question_prelimiter=prelimiter, - continuation_delimiter='', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + question_prelimiter=prelimiter, + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) batch = next(dl.dataloader._get_iterator()) @@ -1152,44 +1330,53 @@ def test_schema_task_dataloader(dataset_uri: str, prelimiter: str, assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' assert 'gold_indices' in batch assert isinstance(batch['gold_indices'], list) and len( - batch['gold_indices']) == batch_size // choices_per_question + batch['gold_indices'], + ) == batch_size // choices_per_question assert 'choice_groupings' in batch assert isinstance(batch['choice_groupings'], list) and len( - batch['choice_groupings']) == batch_size // choices_per_question + batch['choice_groupings'], + ) == batch_size // choices_per_question min_idx = min(batch['continuation_indices'][0]).item() max_idx = max(batch['continuation_indices'][0]).item() - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' feared violence.' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' feared violence.' @pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) -def test_schema_task_dataloader_sentpiece_tokenizer(dataset_uri: str, - tmp_path: Path): +def test_schema_task_dataloader_sentpiece_tokenizer( + dataset_uri: str, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') tokenizer = transformers.AutoTokenizer.from_pretrained( 'huggyllama/llama-7b', # type: ignore reportUnboundVariable - use_fast=False) + use_fast=False, + ) dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 64 - dl = get_icl_task_dataloader('schema', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'schema', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) batch = next(dl.dataloader._get_iterator()) @@ -1201,27 +1388,33 @@ def test_schema_task_dataloader_sentpiece_tokenizer(dataset_uri: str, assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' assert 'gold_indices' in batch assert isinstance(batch['gold_indices'], list) and len( - batch['gold_indices']) == batch_size // choices_per_question + batch['gold_indices'], + ) == batch_size // choices_per_question assert 'choice_groupings' in batch assert isinstance(batch['choice_groupings'], list) and len( - batch['choice_groupings']) == batch_size // choices_per_question + batch['choice_groupings'], + ) == batch_size // choices_per_question max_idx = max(batch['continuation_indices'][0]).item() assert tokenizer.decode( - batch['input_ids'][0][0:max_idx + 1] + batch['input_ids'][0][0:max_idx + 1], ) == "The trophy doesn't fit into the brown suitcase because the suitcase is too small. \nThe city councilmen refused the demonstrators a permit because the city councilmen feared violence." @pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 1]) def test_lm_task_dataloader_opt_tokenizer( - tiny_opt_tokenizer: transformers.AutoTokenizer, dataset_uri: str, - num_fewshot: int, tmp_path: Path): + tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, + num_fewshot: int, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1229,17 +1422,19 @@ def test_lm_task_dataloader_opt_tokenizer( dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 512 - dl = get_icl_task_dataloader('language_modeling', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -1250,13 +1445,15 @@ def test_lm_task_dataloader_opt_tokenizer( assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' min_idx = min(batch['continuation_indices'][0]).item() max_idx = max(batch['continuation_indices'][0]).item() - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' glen' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' glen' assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('') assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('') == 1 @@ -1264,8 +1461,11 @@ def test_lm_task_dataloader_opt_tokenizer( @pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 1]) def test_mc_task_dataloader_opt_tokenizer( - tiny_opt_tokenizer: transformers.AutoTokenizer, dataset_uri: str, - num_fewshot: int, tmp_path: Path): + tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, + num_fewshot: int, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1274,17 +1474,19 @@ def test_mc_task_dataloader_opt_tokenizer( dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 4 seqlen = 64 - dl = get_icl_task_dataloader('multiple_choice', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -1297,28 +1499,36 @@ def test_mc_task_dataloader_opt_tokenizer( assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' assert 'gold_indices' in batch assert isinstance(batch['gold_indices'], list) and len( - batch['gold_indices']) == batch_size // choices_per_question + batch['gold_indices'], + ) == batch_size // choices_per_question assert 'choice_groupings' in batch assert isinstance(batch['choice_groupings'], list) and len( - batch['choice_groupings']) == batch_size // choices_per_question + batch['choice_groupings'], + ) == batch_size // choices_per_question min_idx = min(batch['continuation_indices'][0]).item() max_idx = max(batch['continuation_indices'][0]).item() - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' Pour it onto a plate' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' Pour it onto a plate' assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('') assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('') == 1 @pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 1]) -def test_mc_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, - dataset_uri: str, num_fewshot: int, tmp_path: Path): +def test_mc_split_batch( + tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, + num_fewshot: int, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1327,17 +1537,19 @@ def test_mc_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 4 seqlen = 512 - dl = get_icl_task_dataloader('multiple_choice', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -1349,45 +1561,52 @@ def test_mc_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, for i, microbatch in enumerate(microbatches): assert dl.get_num_samples_in_batch(microbatch) == 1 assert 'input_ids' in microbatch - assert tuple(microbatch['input_ids'].shape) == (real_microbatch_size, - seqlen) + assert tuple( + microbatch['input_ids'].shape, + ) == (real_microbatch_size, seqlen) assert 'attention_mask' in microbatch assert tuple( - microbatch['attention_mask'].shape) == (real_microbatch_size, - seqlen) + microbatch['attention_mask'].shape, + ) == (real_microbatch_size, seqlen) assert 'continuation_indices' in microbatch assert isinstance(microbatch['continuation_indices'], list) and len( - microbatch['continuation_indices']) == real_microbatch_size + microbatch['continuation_indices'], + ) == real_microbatch_size assert 'mode' in microbatch assert microbatch['mode'] == 'icl_task' assert 'gold_indices' in microbatch assert isinstance(microbatch['gold_indices'], list) and len( - microbatch['gold_indices'] + microbatch['gold_indices'], ) == real_microbatch_size // choices_per_question assert 'choice_groupings' in microbatch assert isinstance(microbatch['choice_groupings'], list) and len( - microbatch['choice_groupings'] + microbatch['choice_groupings'], ) == real_microbatch_size // choices_per_question min_idx = min(microbatch['continuation_indices'][0]).item() max_idx = max(microbatch['continuation_indices'][0]).item() if i == 0: assert tokenizer.decode( - microbatch['input_ids'][0][min_idx:max_idx + - 1]) == ' Pour it onto a plate' + microbatch['input_ids'][0][min_idx:max_idx + 1], + ) == ' Pour it onto a plate' elif i == 1: assert tokenizer.decode( - microbatch['input_ids'][0][min_idx:max_idx + 1] + microbatch['input_ids'][0][min_idx:max_idx + 1], ) == ' Weld the metal together to get it to stay firmly in place' assert tokenizer.decode( - microbatch['input_ids'][0][0:min_idx]).startswith('') + microbatch['input_ids'][0][0:min_idx], + ).startswith('') assert tokenizer.decode( - microbatch['input_ids'][0][0:min_idx]).count('') == 1 + microbatch['input_ids'][0][0:min_idx], + ).count('') == 1 @pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) -def test_qa_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, - dataset_uri: str, tmp_path: Path): +def test_qa_split_batch( + tiny_opt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') dataset_uri = f'{local_data}/{dataset_uri}' @@ -1442,8 +1661,12 @@ def test_qa_split_batch(tiny_opt_tokenizer: transformers.AutoTokenizer, @pytest.mark.parametrize('num_fewshot', [0]) @pytest.mark.parametrize('prompt_string', ['I am a prompt', '']) def test_qa_task_dataloader_w_null_eos( - dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path, num_fewshot: int, prompt_string: str): + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + num_fewshot: int, + prompt_string: str, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1453,28 +1676,32 @@ def test_qa_task_dataloader_w_null_eos( seqlen = 512 tiny_gpt2_tokenizer.eos_token_id = None with pytest.raises(ValueError): - _ = get_icl_task_dataloader('generation_task_with_answers', - dataset_uri, - tokenizer, - batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', - destination_path=str( - tmp_path / f'icl_{num_fewshot}.jsonl')) + _ = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri, + tokenizer, + batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter='\nA:', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + ) @pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 2]) @pytest.mark.parametrize('prompt_string', ['I am a prompt', '']) -def test_qa_task_dataloader(dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path, num_fewshot: int, - prompt_string: str): +def test_qa_task_dataloader( + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + num_fewshot: int, + prompt_string: str, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1484,28 +1711,31 @@ def test_qa_task_dataloader(dataset_uri: str, seqlen = 512 # empirical number from the small test dataset maximum_answer_length = 7 - dl = get_icl_task_dataloader('generation_task_with_answers', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', - destination_path=str( - tmp_path / f'icl_{num_fewshot}.jsonl')) + dl = get_icl_task_dataloader( + 'generation_task_with_answers', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + question_prelimiter='Q: ', + continuation_delimiter='\nA:', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) - assert tuple(batch['input_ids'].shape) == (batch_size, - seqlen - maximum_answer_length) - assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - - maximum_answer_length) + assert tuple( + batch['input_ids'].shape, + ) == (batch_size, seqlen - maximum_answer_length) + assert tuple( + batch['attention_mask'].shape, + ) == (batch_size, seqlen - maximum_answer_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data @@ -1519,20 +1749,26 @@ def test_qa_task_dataloader(dataset_uri: str, if len(prompt_string) > 0: assert all(item.count('I am a prompt') == 1 for item in decoded_batch) assert all( - set(found) == set(expected) for found, expected in zip( - batch['labels'], [['David Seville'], ['Skorpio', 'Scorpio']])) + set(found) == set(expected) for found, expected in + zip(batch['labels'], [['David Seville'], ['Skorpio', 'Scorpio']]) + ) assert decoded_batch[0].endswith( - 'Q: Who was the man behind The Chipmunks?\nA:') + 'Q: Who was the man behind The Chipmunks?\nA:', + ) assert decoded_batch[1].endswith( - 'Q: What star sign is Jamie Lee Curtis?\nA:') + 'Q: What star sign is Jamie Lee Curtis?\nA:', + ) assert 'eos_token_id' in batch['generation_kwargs'] @pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 2]) def test_qa_task_with_cot_dataloader( - dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path, num_fewshot: int): + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + num_fewshot: int, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1555,14 +1791,17 @@ def test_qa_task_with_cot_dataloader( question_prelimiter='Q: ', continuation_delimiter="\nA: Let's think step by step. ", cot_delimiter=' #### ', - destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl')) + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) - assert tuple(batch['input_ids'].shape) == (batch_size, - seqlen - maximum_answer_length) - assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - - maximum_answer_length) + assert tuple( + batch['input_ids'].shape, + ) == (batch_size, seqlen - maximum_answer_length) + assert tuple( + batch['attention_mask'].shape, + ) == (batch_size, seqlen - maximum_answer_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length @@ -1574,25 +1813,28 @@ def test_qa_task_with_cot_dataloader( assert batch['labels'] == [['18'], ['12334']] if num_fewshot == 0: assert decoded_batch[0].endswith( - "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step." + "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step.", ) assert decoded_batch[1].endswith( - "Q: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step." + "Q: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step.", ) elif num_fewshot == 2: assert decoded_batch[0].endswith( - "Q: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\nA: Let's think step by step. He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters #### 540\nQ: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step." + "Q: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\nA: Let's think step by step. He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters #### 540\nQ: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step.", ) assert decoded_batch[1].endswith( - "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step. Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market. #### 18\nQ: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step." + "Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nA: Let's think step by step. Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market. #### 18\nQ: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\nA: Let's think step by step. The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000 #### 70000\nQ: A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?\nA: Let's think step by step.", ) @pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl']) @pytest.mark.parametrize('prelimiter', ['', 'This is a question: ']) -def test_mc_task_dataloader(dataset_uri: str, prelimiter: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_mc_task_dataloader( + dataset_uri: str, + prelimiter: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1601,18 +1843,20 @@ def test_mc_task_dataloader(dataset_uri: str, prelimiter: str, batch_size = 2 seqlen = 64 example_delimiter = '\n' - dl = get_icl_task_dataloader('multiple_choice', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - question_prelimiter=prelimiter, - example_delimiter=example_delimiter, - continuation_delimiter='\nA: ', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + question_prelimiter=prelimiter, + example_delimiter=example_delimiter, + continuation_delimiter='\nA: ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -1624,28 +1868,33 @@ def test_mc_task_dataloader(dataset_uri: str, prelimiter: str, assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' assert 'gold_indices' in batch assert isinstance(batch['gold_indices'], list) and len( - batch['gold_indices']) == batch_size // choices_per_question + batch['gold_indices'], + ) == batch_size // choices_per_question assert 'choice_groupings' in batch assert isinstance(batch['choice_groupings'], list) and len( - batch['choice_groupings']) == batch_size // choices_per_question + batch['choice_groupings'], + ) == batch_size // choices_per_question min_idx = min(batch['continuation_indices'][0]).item() max_idx = max(batch['continuation_indices'][0]).item() - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' Pour it onto a plate' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' Pour it onto a plate' q1 = 'how do you shake something?\nA: ' a1 = 'move it up and down and side to side quickly.' q2 = "When boiling butter, when it's ready, you can\nA:" assert tokenizer.decode( - batch['input_ids'][0][:min_idx] + batch['input_ids'][0][:min_idx], ) == f'{prelimiter}{q1}{a1}{example_delimiter}{prelimiter}{q2}' - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' Pour it onto a plate' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' Pour it onto a plate' @pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) @@ -1654,7 +1903,8 @@ def test_code_eval_split_batch(dataset_uri: str, tmp_path: Path): dataset_uri = f'{local_data}/{dataset_uri}' tokenizer = transformers.AutoTokenizer.from_pretrained( - 'EleutherAI/gpt-neox-20b') # type: ignore reportUnboundVariable + 'EleutherAI/gpt-neox-20b', + ) # type: ignore reportUnboundVariable tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) @@ -1708,9 +1958,13 @@ def test_code_eval_split_batch(dataset_uri: str, tmp_path: Path): @pytest.mark.parametrize('prompt_string', ['Please code:\n', '']) @pytest.mark.parametrize('generations_per_sample', [1, 3]) def test_code_eval_sentpiece_dataloader( - dataset_uri: str, tmp_path: Path, num_fewshot: int, prompt_string: str, - generations_per_sample: int, - tiny_llama_tokenizer: transformers.AutoTokenizer): + dataset_uri: str, + tmp_path: Path, + num_fewshot: int, + prompt_string: str, + generations_per_sample: int, + tiny_llama_tokenizer: transformers.AutoTokenizer, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -1719,20 +1973,21 @@ def test_code_eval_sentpiece_dataloader( batch_size = 5 seqlen = 2048 - dl = get_icl_task_dataloader('code_evaluation', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - continuation_delimiter='', - question_prelimiter='Code start: \n', - destination_path=str( - tmp_path / f'icl_{num_fewshot}.jsonl'), - generations_per_sample=generations_per_sample) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=generations_per_sample, + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1753,8 +2008,9 @@ def test_code_eval_sentpiece_dataloader( assert batch['mode'] == 'generate' # the maximum generation length from the small test data assert batch['generation_kwargs']['max_new_tokens'] == 129 - has_left_padding.extend( - [item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) + has_left_padding.extend([ + item[0] == tokenizer.eos_token_id for item in batch['input_ids'] + ]) assert not all(has_left_padding) # longest should be pushed left decoded_batches = [ @@ -1763,11 +2019,13 @@ def test_code_eval_sentpiece_dataloader( for decoded_batch in decoded_batches: assert all( item.count('Code start: \n') == num_fewshot + 1 - for item in decoded_batch) + for item in decoded_batch + ) if len(prompt_string) > 0: assert all( - item.count('Please code:\n') == 1 for item in decoded_batch) + item.count('Please code:\n') == 1 for item in decoded_batch + ) labels = [ ' for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n', @@ -1781,7 +2039,7 @@ def test_code_eval_sentpiece_dataloader( "Code start: \nfrom typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", "Code start: \nfrom typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", "Code start: \n\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", - "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n" + "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n", ] for i in range(4): for j in range(generations_per_sample): @@ -1796,24 +2054,27 @@ def test_code_eval_test_cases(dataset_uri: str, tmp_path: Path): local_data = os.path.join(os.path.dirname(__file__), 'local_data') tokenizer = transformers.AutoTokenizer.from_pretrained( - 'huggyllama/llama-7b') # type: ignore reportUnboundVariable + 'huggyllama/llama-7b', + ) # type: ignore reportUnboundVariable dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 4 seqlen = 512 - dl = get_icl_task_dataloader('code_evaluation', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', - question_prelimiter='Code start: \n', - destination_path=str(tmp_path / f'icl_.jsonl'), - generations_per_sample=1) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_.jsonl'), + generations_per_sample=1, + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1823,18 +2084,24 @@ def test_code_eval_test_cases(dataset_uri: str, tmp_path: Path): if isinstance(dl.dataloader.dataset, InContextLearningCodeEvalDataset): max_prompt_length = dl.dataloader.dataset.max_prompt_length assert tuple(batch['input_ids'].shape) == (batch_size, max_prompt_length) - assert tuple(batch['attention_mask'].shape) == (batch_size, - max_prompt_length) + assert tuple( + batch['attention_mask'].shape, + ) == (batch_size, max_prompt_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data assert batch['generation_kwargs']['max_new_tokens'] == 129 - assert any(item[0] != tokenizer.eos_token_id - for item in batch['input_ids']) # longest should be pushed left + assert any( + item[0] != tokenizer.eos_token_id for item in batch['input_ids'] + ) # longest should be pushed left mod = types.ModuleType('test_module') for prompt, solution, inputs, outputs, entry_point in zip( - batch['prompts'], batch['labels'], batch['test_inputs'], - batch['test_outputs'], batch['entry_points']): + batch['prompts'], + batch['labels'], + batch['test_inputs'], + batch['test_outputs'], + batch['entry_points'], + ): exec(prompt + solution, mod.__dict__) for test_input, test_output in zip(inputs, outputs): result = mod.__dict__[entry_point](*eval(test_input)) @@ -1846,62 +2113,71 @@ def test_code_eval_pass_at_k_validity(dataset_uri: str, tmp_path: Path): local_data = os.path.join(os.path.dirname(__file__), 'local_data') tokenizer = transformers.AutoTokenizer.from_pretrained( - 'huggyllama/llama-7b') # type: ignore reportUnboundVariable + 'huggyllama/llama-7b', + ) # type: ignore reportUnboundVariable dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 64 with pytest.raises(ValueError, match=r'.* pass_at_k .*'): - get_icl_task_dataloader('code_evaluation', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', - question_prelimiter='Code start: \n', - destination_path=str(tmp_path / f'icl_.jsonl'), - pass_at_k=10, - generations_per_sample=1) + get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=0, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_.jsonl'), + pass_at_k=10, + generations_per_sample=1, + ) @pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 2]) @pytest.mark.parametrize('prompt_string', ['Please code:\n', '']) @pytest.mark.parametrize('generations_per_sample', [1, 3]) -def test_code_eval_task_dataloader(dataset_uri: str, tmp_path: Path, - num_fewshot: int, prompt_string: str, - generations_per_sample: int): +def test_code_eval_task_dataloader( + dataset_uri: str, + tmp_path: Path, + num_fewshot: int, + prompt_string: str, + generations_per_sample: int, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') tokenizer = transformers.AutoTokenizer.from_pretrained( - 'mosaicml/mpt-7b') # type: ignore reportUnboundVariable + 'mosaicml/mpt-7b', + ) # type: ignore reportUnboundVariable dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 4 seqlen = 2048 - dl = get_icl_task_dataloader('code_evaluation', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - continuation_delimiter='', - question_prelimiter='Code start: \n', - destination_path=str( - tmp_path / f'icl_{num_fewshot}.jsonl'), - generations_per_sample=generations_per_sample, - generation_kwargs={ - 'temperature': .9, - 'top_k': 40 - }) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string=prompt_string, + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=generations_per_sample, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40, + }, + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1921,8 +2197,9 @@ def test_code_eval_task_dataloader(dataset_uri: str, tmp_path: Path, assert batch['mode'] == 'generate' # the maximum generation length from the small test data assert batch['generation_kwargs']['max_new_tokens'] == 122 - has_left_padding.extend( - [item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) + has_left_padding.extend([ + item[0] == tokenizer.eos_token_id for item in batch['input_ids'] + ]) assert not all(has_left_padding) # longest should be pushed left decoded_batches = [ @@ -1931,11 +2208,13 @@ def test_code_eval_task_dataloader(dataset_uri: str, tmp_path: Path, for decoded_batch in decoded_batches: assert all( item.count('Code start: \n') == num_fewshot + 1 - for item in decoded_batch) + for item in decoded_batch + ) if len(prompt_string) > 0: assert all( - item.count('Please code:\n') == 1 for item in decoded_batch) + item.count('Please code:\n') == 1 for item in decoded_batch + ) labels = [ ' for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n', @@ -1949,7 +2228,7 @@ def test_code_eval_task_dataloader(dataset_uri: str, tmp_path: Path, "Code start: \nfrom typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", "Code start: \nfrom typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", "Code start: \n\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", - "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n" + "Code start: \nfrom typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n", ] for i in range(4): for j in range(generations_per_sample): @@ -1961,8 +2240,12 @@ def test_code_eval_task_dataloader(dataset_uri: str, tmp_path: Path, @pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [0, 1]) -def test_eval_split_batch(mpt_tokenizer: transformers.AutoTokenizer, - dataset_uri: str, num_fewshot: int, tmp_path: Path): +def test_eval_split_batch( + mpt_tokenizer: transformers.AutoTokenizer, + dataset_uri: str, + num_fewshot: int, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') tokenizer = mpt_tokenizer # type: ignore reportUnboundVariable @@ -1970,24 +2253,25 @@ def test_eval_split_batch(mpt_tokenizer: transformers.AutoTokenizer, batch_size = 4 seqlen = 512 - dl = get_icl_task_dataloader('code_evaluation', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', - question_prelimiter='Code start: \n', - destination_path=str( - tmp_path / f'icl_{num_fewshot}.jsonl'), - generations_per_sample=1, - generation_kwargs={ - 'temperature': .9, - 'top_k': 40 - }) + dl = get_icl_task_dataloader( + 'code_evaluation', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter='', + question_prelimiter='Code start: \n', + destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + generations_per_sample=1, + generation_kwargs={ + 'temperature': .9, + 'top_k': 40, + }, + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -2015,10 +2299,13 @@ def test_eval_split_batch(mpt_tokenizer: transformers.AutoTokenizer, @pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) # @pytest.mark.gpu # @pytest.mark.world_size(2) -def test_lm_task_evaluation(num_fewshot: int, dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path, - tiny_gpt2_model: transformers.AutoModelForCausalLM): +def test_lm_task_evaluation( + num_fewshot: int, + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2040,9 +2327,11 @@ def test_lm_task_evaluation(num_fewshot: int, dataset_uri: str, destination_path=str(tmp_path / 'icl.jsonl'), ) - evaluator = Evaluator(label='lambada', - dataloader=dl, - metric_names=['InContextLearningLMAccuracy']) + evaluator = Evaluator( + label='lambada', + dataloader=dl, + metric_names=['InContextLearningLMAccuracy'], + ) model = HuggingFaceModel( model=tiny_gpt2_model, @@ -2063,9 +2352,12 @@ def test_lm_task_evaluation(num_fewshot: int, dataset_uri: str, @pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl']) @pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') def test_schema_task_evaluation( - num_fewshot: int, dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path, - tiny_gpt2_model: transformers.AutoModelForCausalLM): + num_fewshot: int, + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2090,7 +2382,8 @@ def test_schema_task_evaluation( evaluator = Evaluator( label='winograd', dataloader=dl, - metric_names=['InContextLearningMultipleChoiceAccuracy']) + metric_names=['InContextLearningMultipleChoiceAccuracy'], + ) model = HuggingFaceModel( model=tiny_gpt2_model, @@ -2120,9 +2413,12 @@ def test_schema_task_evaluation( @pytest.mark.world_size(2) @pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') def test_mc_task_evaluation_subcategories( - dataset_uri: str, num_fewshot: int, - tiny_gpt2_model: transformers.AutoModelForCausalLM, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, tmp_path: Path): + dataset_uri: str, + num_fewshot: int, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2134,26 +2430,28 @@ def test_mc_task_evaluation_subcategories( tmp_path_to_broadcast = str(os.path.abspath(tmp_path)) gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) reproducibility.seed_all(1234) - dls = get_icl_task_dataloader('multiple_choice', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', - destination_path=str( - Path(gathered_paths[0]) / 'icl.jsonl'), - has_categories=True) + dls = get_icl_task_dataloader( + 'multiple_choice', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=max_seq_len, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=num_fewshot, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=': ', + destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + has_categories=True, + ) assert isinstance(dls, dict) evaluators = [ - Evaluator(label='mmlu/' + k, - dataloader=dl, - metric_names=['InContextLearningMultipleChoiceAccuracy']) - for k, dl in dls.items() + Evaluator( + label='mmlu/' + k, + dataloader=dl, + metric_names=['InContextLearningMultipleChoiceAccuracy'], + ) for k, dl in dls.items() ] model = HuggingFaceModel( @@ -2168,24 +2466,29 @@ def test_mc_task_evaluation_subcategories( assert 'metrics/mmlu/computer_security/InContextLearningMultipleChoiceAccuracy' in in_memory_logger.data.keys( ) assert in_memory_logger.data[ - 'metrics/mmlu/computer_security/InContextLearningMultipleChoiceAccuracy'][ - 0][1].item() >= 0 + 'metrics/mmlu/computer_security/InContextLearningMultipleChoiceAccuracy' + ][0][1].item() >= 0 total = trainer.state.eval_metrics['mmlu/computer_security'][ 'InContextLearningMultipleChoiceAccuracy'].total dist.all_reduce(total) # type: ignore assert total.item() == 4 # type: ignore -@pytest.mark.parametrize('dataset_uri', - ['piqa_small.jsonl', 'hellaswag_small.jsonl']) +@pytest.mark.parametrize( + 'dataset_uri', + ['piqa_small.jsonl', 'hellaswag_small.jsonl'], +) @pytest.mark.parametrize('num_fewshot', [0, 5]) @pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') @pytest.mark.gpu @pytest.mark.world_size(2) -def test_mc_task_evaluation(num_fewshot: int, dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path, - tiny_gpt2_model: transformers.AutoModelForCausalLM): +def test_mc_task_evaluation( + num_fewshot: int, + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + tiny_gpt2_model: transformers.AutoModelForCausalLM, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2215,7 +2518,8 @@ def test_mc_task_evaluation(num_fewshot: int, dataset_uri: str, evaluator = Evaluator( label='mc', dataloader=dl, - metric_names=['InContextLearningMultipleChoiceAccuracy']) + metric_names=['InContextLearningMultipleChoiceAccuracy'], + ) model = HuggingFaceModel( model=tiny_gpt2_model, @@ -2243,15 +2547,18 @@ def test_mc_task_evaluation(num_fewshot: int, dataset_uri: str, @pytest.mark.parametrize('num_fewshot', [0, 5]) @pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl']) @pytest.mark.filterwarnings( - r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning', ) @pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') @pytest.mark.gpu @pytest.mark.world_size(2) def test_qa_task_evaluation_opt_tokenizer( - tiny_opt_tokenizer: transformers.AutoTokenizer, - tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, - dataset_uri: str, tmp_path: Path): + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, + num_fewshot: int, + dataset_uri: str, + tmp_path: Path, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2279,7 +2586,8 @@ def test_qa_task_evaluation_opt_tokenizer( evaluator = Evaluator( label='triviaqa', dataloader=dl, - metric_names=['InContextLearningGenerationExactMatchAccuracy']) + metric_names=['InContextLearningGenerationExactMatchAccuracy'], + ) model = HuggingFaceModel( model=tiny_opt_model, tokenizer=tokenizer, @@ -2302,13 +2610,16 @@ def test_qa_task_evaluation_opt_tokenizer( @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.filterwarnings( - r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning', ) @pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning') def test_qa_task_evaluation_with_cot_opt_tokenizer( - tiny_opt_tokenizer: transformers.AutoTokenizer, - tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, - dataset_uri: str, tmp_path: Path): + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, + num_fewshot: int, + dataset_uri: str, + tmp_path: Path, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2337,7 +2648,8 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer( evaluator = Evaluator( label='gsm8k', dataloader=dl, - metric_names=['InContextLearningGenerationExactMatchAccuracy']) + metric_names=['InContextLearningGenerationExactMatchAccuracy'], + ) model = HuggingFaceModel( model=tiny_opt_model, tokenizer=tokenizer, @@ -2360,12 +2672,15 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer( @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.filterwarnings( - r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning', ) -def test_qa_task_evaluation(num_fewshot: int, dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tiny_gpt2_model: transformers.AutoModelForCausalLM, - tmp_path: Path): +def test_qa_task_evaluation( + num_fewshot: int, + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tmp_path: Path, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2392,7 +2707,8 @@ def test_qa_task_evaluation(num_fewshot: int, dataset_uri: str, evaluator = Evaluator( label='triviaqa', dataloader=dl, - metric_names=['InContextLearningGenerationExactMatchAccuracy']) + metric_names=['InContextLearningGenerationExactMatchAccuracy'], + ) model = HuggingFaceModel( model=tiny_gpt2_model, @@ -2414,14 +2730,17 @@ def test_qa_task_evaluation(num_fewshot: int, dataset_uri: str, @pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl']) @pytest.mark.parametrize('num_fewshot', [5]) @pytest.mark.filterwarnings( - r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning', ) @pytest.mark.gpu @pytest.mark.world_size(2) def test_qa_task_with_cot_evaluation( - num_fewshot: int, dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tiny_gpt2_model: transformers.AutoModelForCausalLM, tmp_path: Path): + num_fewshot: int, + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tmp_path: Path, +): in_memory_logger = InMemoryLogger( ) # track the logged metrics in the in_memory_logger @@ -2449,7 +2768,8 @@ def test_qa_task_with_cot_evaluation( evaluator = Evaluator( label='gsm8k', dataloader=dl, - metric_names=['InContextLearningGenerationExactMatchAccuracy']) + metric_names=['InContextLearningGenerationExactMatchAccuracy'], + ) model = HuggingFaceModel( model=tiny_gpt2_model, @@ -2471,16 +2791,18 @@ def test_qa_task_with_cot_evaluation( def test_code_eval_requires_envvar(monkeypatch: pytest.MonkeyPatch): monkeypatch.delenv('CODE_EVAL_DEVICE', raising=False) with pytest.raises( - ValueError, - match='Attempting to use InContextLearningCodeEvalAccuracy but.*'): + ValueError, + match='Attempting to use InContextLearningCodeEvalAccuracy but.*', + ): InContextLearningCodeEvalAccuracy().get_client() def test_code_eval_requires_valid_envvar(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv('CODE_EVAL_DEVICE', 'bigchungus') with pytest.raises( - ValueError, - match='Environment variable `CODE_EVAL_DEVICE` must be on.*'): + ValueError, + match='Environment variable `CODE_EVAL_DEVICE` must be on.*', + ): InContextLearningCodeEvalAccuracy().get_client() @@ -2490,13 +2812,17 @@ def test_code_eval_requires_valid_envvar(monkeypatch: pytest.MonkeyPatch): @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.filterwarnings( - r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning', ) def test_code_eval_microbatching( - monkeypatch: pytest.MonkeyPatch, - tiny_opt_tokenizer: transformers.AutoTokenizer, - tiny_opt_model: transformers.AutoModelForCausalLM, num_fewshot: int, - dataset_uri: str, tmp_path: Path, generations_per_sample: int): + monkeypatch: pytest.MonkeyPatch, + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, + num_fewshot: int, + dataset_uri: str, + tmp_path: Path, + generations_per_sample: int, +): monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') in_memory_logger = InMemoryLogger( @@ -2523,10 +2849,12 @@ def test_code_eval_microbatching( generations_per_sample=generations_per_sample, ) - evaluator = Evaluator(label='humaneval', - dataloader=dl, - metric_names=['InContextLearningCodeEvalAccuracy'], - device_eval_microbatch_size=1) + evaluator = Evaluator( + label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy'], + device_eval_microbatch_size=1, + ) model = HuggingFaceModel( model=tiny_opt_model, tokenizer=tokenizer, @@ -2550,13 +2878,17 @@ def test_code_eval_microbatching( @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.filterwarnings( - r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning', ) def test_code_eval_sentpiece_evaluation( - monkeypatch: pytest.MonkeyPatch, num_fewshot: int, dataset_uri: str, - tiny_opt_tokenizer: transformers.AutoTokenizer, - tiny_opt_model: transformers.AutoModelForCausalLM, tmp_path: Path, - generations_per_sample: int): + monkeypatch: pytest.MonkeyPatch, + num_fewshot: int, + dataset_uri: str, + tiny_opt_tokenizer: transformers.AutoTokenizer, + tiny_opt_model: transformers.AutoModelForCausalLM, + tmp_path: Path, + generations_per_sample: int, +): monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') in_memory_logger = InMemoryLogger( @@ -2582,9 +2914,11 @@ def test_code_eval_sentpiece_evaluation( generations_per_sample=generations_per_sample, ) - evaluator = Evaluator(label='humaneval', - dataloader=dl, - metric_names=['InContextLearningCodeEvalAccuracy']) + evaluator = Evaluator( + label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy'], + ) model = HuggingFaceModel( model=tiny_opt_model, tokenizer=tiny_opt_tokenizer, @@ -2609,13 +2943,17 @@ def test_code_eval_sentpiece_evaluation( @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.filterwarnings( - r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning' + r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning', ) def test_code_eval_task_evaluation( - monkeypatch: pytest.MonkeyPatch, num_fewshot: int, dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tiny_gpt2_model: transformers.AutoModelForCausalLM, tmp_path: Path, - generations_per_sample: int): + monkeypatch: pytest.MonkeyPatch, + num_fewshot: int, + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tiny_gpt2_model: transformers.AutoModelForCausalLM, + tmp_path: Path, + generations_per_sample: int, +): monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL') in_memory_logger = InMemoryLogger( @@ -2641,9 +2979,11 @@ def test_code_eval_task_evaluation( generations_per_sample=generations_per_sample, ) - evaluator = Evaluator(label='humaneval', - dataloader=dl, - metric_names=['InContextLearningCodeEvalAccuracy']) + evaluator = Evaluator( + label='humaneval', + dataloader=dl, + metric_names=['InContextLearningCodeEvalAccuracy'], + ) model = HuggingFaceModel( model=tiny_gpt2_model, tokenizer=tiny_gpt2_tokenizer, @@ -2662,9 +3002,11 @@ def test_code_eval_task_evaluation( @pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl']) -def test_lm_spacing_dataloader(dataset_uri: str, - tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path): +def test_lm_spacing_dataloader( + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, +): local_data = os.path.join(os.path.dirname(__file__), 'local_data') @@ -2672,26 +3014,32 @@ def test_lm_spacing_dataloader(dataset_uri: str, dataset_uri = f'{local_data}/{dataset_uri}' batch_size = 2 seqlen = 512 - dl = get_icl_task_dataloader('language_modeling', - dataset_uri=dataset_uri, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' UNIQUE ', - destination_path=str(tmp_path / 'icl.jsonl')) + dl = get_icl_task_dataloader( + 'language_modeling', + dataset_uri=dataset_uri, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_len=seqlen, + pad_tok_id=tokenizer.eos_token_id, + num_fewshot=1, + prompt_string='', + example_delimiter='\n', + continuation_delimiter=' UNIQUE ', + destination_path=str(tmp_path / 'icl.jsonl'), + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright first_batch = next(dl.dataloader._get_iterator()) second_batch = next(dl.dataloader._get_iterator()) - first_batch_text = tokenizer.decode(first_batch['input_ids'][0], - skip_special_tokens=True) - second_batch_text = tokenizer.decode(second_batch['input_ids'][0], - skip_special_tokens=True) + first_batch_text = tokenizer.decode( + first_batch['input_ids'][0], + skip_special_tokens=True, + ) + second_batch_text = tokenizer.decode( + second_batch['input_ids'][0], + skip_special_tokens=True, + ) first_batch_without_last_word = ' '.join(first_batch_text.split(' ')[:-1]) second_batch_without_last_word = ' '.join(second_batch_text.split(' ')[:-1]) @@ -2706,25 +3054,32 @@ def test_lm_spacing_dataloader(dataset_uri: str, @pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset']) @pytest.mark.parametrize('num_fewshot', [0, 1]) @pytest.mark.parametrize('prompt_string', ['Complete the voiceline: ', '']) -@pytest.mark.parametrize('hf_loading_vars', [{ - 'split': 'test', - 'name': 'juggernaut', -}]) +@pytest.mark.parametrize( + 'hf_loading_vars', + [{ + 'split': 'test', + 'name': 'juggernaut', + }], +) @pytest.mark.parametrize( 'hf_parsing_map', [None, { 'context': ['context'], - 'continuation': ['continuation'] - }]) + 'continuation': ['continuation'], + }], +) @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_hf_dataloading_lm_dataloader( - dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path, num_fewshot: int, prompt_string: str, - hf_loading_vars: Dict[str, - str], hf_parsing_map: Optional[Dict[str, - List[str]]]): + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + num_fewshot: int, + prompt_string: str, + hf_loading_vars: Dict[str, str], + hf_parsing_map: Optional[Dict[str, List[str]]], +): tokenizer = tiny_gpt2_tokenizer batch_size = 2 @@ -2742,7 +3097,8 @@ def test_hf_dataloading_lm_dataloader( continuation_delimiter=' ', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map) + hf_parsing_map=hf_parsing_map, + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) @@ -2753,13 +3109,15 @@ def test_hf_dataloading_lm_dataloader( assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen) assert 'continuation_indices' in batch assert isinstance(batch['continuation_indices'], list) and len( - batch['continuation_indices']) == batch_size + batch['continuation_indices'], + ) == batch_size assert 'mode' in batch assert batch['mode'] == 'icl_task' min_idx = min(batch['continuation_indices'][0]).item() max_idx = max(batch['continuation_indices'][0]).item() - assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + - 1]) == ' and me.' + assert tokenizer.decode( + batch['input_ids'][0][min_idx:max_idx + 1], + ) == ' and me.' decoded_batch = [ tokenizer.decode(row[row != tokenizer.eos_token_id]) @@ -2773,21 +3131,32 @@ def test_hf_dataloading_lm_dataloader( @pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset']) @pytest.mark.parametrize('num_fewshot', [0, 1]) @pytest.mark.parametrize('prompt_string', ['What spell does this invoke? ', '']) -@pytest.mark.parametrize('hf_loading_vars', [{ - 'split': 'test', - 'name': 'invoker', -}]) -@pytest.mark.parametrize('hf_parsing_map', [{ - 'context': ['quas', 'wex', 'exort'], - 'answer': ['spell'] -}]) +@pytest.mark.parametrize( + 'hf_loading_vars', + [{ + 'split': 'test', + 'name': 'invoker', + }], +) +@pytest.mark.parametrize( + 'hf_parsing_map', + [{ + 'context': ['quas', 'wex', 'exort'], + 'answer': ['spell'], + }], +) @pytest.mark.filterwarnings( - r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning' + r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning', ) def test_hf_dataloading_custom_parsing( - dataset_uri: str, tiny_gpt2_tokenizer: transformers.AutoTokenizer, - tmp_path: Path, num_fewshot: int, prompt_string: str, - hf_loading_vars: Dict[str, str], hf_parsing_map: Dict[str, List[str]]): + dataset_uri: str, + tiny_gpt2_tokenizer: transformers.AutoTokenizer, + tmp_path: Path, + num_fewshot: int, + prompt_string: str, + hf_loading_vars: Dict[str, str], + hf_parsing_map: Dict[str, List[str]], +): tokenizer = tiny_gpt2_tokenizer batch_size = 2 @@ -2810,15 +3179,18 @@ def test_hf_dataloading_custom_parsing( continuation_delimiter='\nSpell:', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map) + hf_parsing_map=hf_parsing_map, + ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright batch = next(dl.dataloader._get_iterator()) - assert tuple(batch['input_ids'].shape) == (batch_size, - seqlen - maximum_answer_length) - assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - - maximum_answer_length) + assert tuple( + batch['input_ids'].shape, + ) == (batch_size, seqlen - maximum_answer_length) + assert tuple( + batch['attention_mask'].shape, + ) == (batch_size, seqlen - maximum_answer_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length @@ -2826,16 +3198,20 @@ def test_hf_dataloading_custom_parsing( decoded_batch = tokenizer.batch_decode(batch['input_ids']) assert all( - item.count('Orbs: ') == num_fewshot + 1 for item in decoded_batch) + item.count('Orbs: ') == num_fewshot + 1 for item in decoded_batch + ) assert all( - item.count('\nSpell:') == num_fewshot + 1 for item in decoded_batch) + item.count('\nSpell:') == num_fewshot + 1 for item in decoded_batch + ) if len(prompt_string) > 0: assert all( item.count('What spell does this invoke? ') == 1 - for item in decoded_batch) + for item in decoded_batch + ) assert all( - set(found) == set(expected) for found, expected in zip( - batch['labels'], [['defeaning blast'], ['cold snap']])) + set(found) == set(expected) for found, expected in + zip(batch['labels'], [['defeaning blast'], ['cold snap']]) + ) assert decoded_batch[0].endswith('Orbs: quas wex exort\nSpell:') assert decoded_batch[1].endswith('Orbs: quas quas quas\nSpell:') diff --git a/tests/eval/test_nlp_metrics.py b/tests/eval/test_nlp_metrics.py index 344d642715..e07be4d863 100644 --- a/tests/eval/test_nlp_metrics.py +++ b/tests/eval/test_nlp_metrics.py @@ -9,12 +9,15 @@ from llmfoundry.eval.metrics import ( InContextLearningCodeEvalAccuracy, - InContextLearningGenerationExactMatchAccuracy, InContextLearningLMAccuracy, - InContextLearningMultipleChoiceAccuracy) + InContextLearningGenerationExactMatchAccuracy, + InContextLearningLMAccuracy, + InContextLearningMultipleChoiceAccuracy, +) def test_in_context_learning_lm_accuracy( - tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): contexts = ['The dog is', 'I love to eat', 'I hate', 'The weather is'] continuations = [' furry', ' pie', ' long lines', ' snowy'] pad = tiny_gpt2_tokenizer.pad_token_id @@ -23,8 +26,9 @@ def test_in_context_learning_lm_accuracy( tiny_gpt2_tokenizer(continuation)['input_ids'] for context, continuation in zip(contexts, continuations) ] - inputs = torch.tensor( - [input + [pad] * (2048 - len(input)) for input in inputs]) + inputs = torch.tensor([ + input + [pad] * (2048 - len(input)) for input in inputs + ]) cont_idxs = [] for context, continuation in zip(contexts, continuations): @@ -35,7 +39,7 @@ def test_in_context_learning_lm_accuracy( batch = { 'continuation_indices': cont_idxs, 'labels': inputs.roll(-1), - 'input_ids': inputs + 'input_ids': inputs, } logits = torch.nn.functional.one_hot(inputs.roll(-1), num_classes=pad + 1).float() * 100 @@ -50,8 +54,9 @@ def test_in_context_learning_lm_accuracy( def test_in_context_learning_qa_accuracy(): outputs = [ - 'Correct but then some more text', 'Incorrect', - ' the CORREct with weird casing and spacing' + 'Correct but then some more text', + 'Incorrect', + ' the CORREct with weird casing and spacing', ] labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct']] batch = {'cot_delimiter': '', 'labels': labels} @@ -66,14 +71,14 @@ def test_in_context_learning_qa_cot_accuracy(): 'chain of thought ### Correct but then some more text\n\nanother chain of thought ### Incorrect answer this time', 'Incorrect', 'chain of thought ### the CORREct with weird casing and spacing', - 'incorrect chain of thought delimiter ## Correct but wrong delimiter' + 'incorrect chain of thought delimiter ## Correct but wrong delimiter', ] labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct'], ['correct']] batch = { 'cot_delimiter': ' ### ', 'labels': labels, 'do_normalization': True, - 'stopping_criteria': '\n\n' + 'stopping_criteria': '\n\n', } metric = InContextLearningGenerationExactMatchAccuracy() metric.update(batch, outputs, labels) @@ -82,18 +87,21 @@ def test_in_context_learning_qa_cot_accuracy(): def test_in_context_learning_code_eval_accuracy( - monkeypatch: pytest.MonkeyPatch): + monkeypatch: pytest.MonkeyPatch, +): outputs = [ ' return 1 if n <= 1 else fib(n - 1) + fib(n - 1)', # incorrect ' if n <= 1:\n return 1\n return fib(n-1) + fib(n-2)', # incorrect spacing ' return n * 2', # correct ' return 2*n', # correct ' return n + 2', # incorrect - ' return n + 1' + ' return n + 1', ] # correct labels = [] prompts = [ - 'def fib(n):\n', 'def multiply_by_two(n):\n', 'def add_one(n):\n' + 'def fib(n):\n', + 'def multiply_by_two(n):\n', + 'def add_one(n):\n', ] entry_points = ['fib', 'multiply_by_two', 'add_one'] test_inputs = [['(1,)', '(2,)', '(4,)'], ['(1,)', '(2,)', '(4,)'], @@ -109,11 +117,14 @@ def repeat(values: List[Any]): transformers = pytest.importorskip('transformers') tokenizer = transformers.AutoTokenizer.from_pretrained( - 'mosaicml/mpt-7b') # type: ignore reportUnboundVariable + 'mosaicml/mpt-7b', + ) # type: ignore reportUnboundVariable tokenizer.pad_token = tokenizer.eos_token - input_ids = tokenizer.batch_encode_plus(repeat(prompts), - return_tensors='pt', - padding=True)['input_ids'] + input_ids = tokenizer.batch_encode_plus( + repeat(prompts), + return_tensors='pt', + padding=True, + )['input_ids'] batch = { # This tests deterministic beam search rather than sampling 'input_ids': input_ids, @@ -142,14 +153,19 @@ def repeat(values: List[Any]): def test_in_context_learning_mc_accuracy( - tiny_gpt2_tokenizer: transformers.AutoTokenizer): + tiny_gpt2_tokenizer: transformers.AutoTokenizer, +): contexts = [ - 'Q: How do you cook a cake?', 'Q: How do you cook a cake?', - 'Q: How old is the earth?', 'Q: How old is the earth?' + 'Q: How do you cook a cake?', + 'Q: How do you cook a cake?', + 'Q: How old is the earth?', + 'Q: How old is the earth?', ] continuations = [ - ' A: turn on the oven', ' A: do a backflip', ' A: 2 minutes', - ' A: 4.5 billion years' + ' A: turn on the oven', + ' A: do a backflip', + ' A: 2 minutes', + ' A: 4.5 billion years', ] gold_indices = [0, 1] choice_groupings = [(0, 2), (2, 4)] @@ -159,8 +175,9 @@ def test_in_context_learning_mc_accuracy( tiny_gpt2_tokenizer(continuation)['input_ids'] for context, continuation in zip(contexts, continuations) ] - inputs = torch.tensor( - [input + [pad] * (2048 - len(input)) for input in inputs]) + inputs = torch.tensor([ + input + [pad] * (2048 - len(input)) for input in inputs + ]) attention_mask = ~(inputs == pad) cont_idxs = [] @@ -175,7 +192,7 @@ def test_in_context_learning_mc_accuracy( 'input_ids': inputs, 'attention_mask': attention_mask, 'gold_indices': gold_indices, - 'choice_groupings': choice_groupings + 'choice_groupings': choice_groupings, } logits = torch.nn.functional.one_hot(inputs.roll(-1), num_classes=pad + 1).float() diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py index 9ba053ffe8..cd85bd2603 100644 --- a/tests/fixtures/data.py +++ b/tests/fixtures/data.py @@ -27,10 +27,12 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path: @fixture @patch('os.cpu_count', MagicMock(return_value=1)) -def tiny_ft_dataloader(tiny_ft_dataset_path: Path, - mpt_tokenizer: PreTrainedTokenizerBase, - max_seq_len: int = 128, - device_batch_size: int = 1) -> DataLoader: +def tiny_ft_dataloader( + tiny_ft_dataset_path: Path, + mpt_tokenizer: PreTrainedTokenizerBase, + max_seq_len: int = 128, + device_batch_size: int = 1, +) -> DataLoader: dataloader_cfg = DictConfig({ 'name': 'finetuning', 'dataset': { @@ -47,7 +49,7 @@ def tiny_ft_dataloader(tiny_ft_dataset_path: Path, 'pin_memory': False, 'prefetch_factor': 2, 'persistent_workers': False, - 'timeout': 0 + 'timeout': 0, }) dataloader = build_finetuning_dataloader( diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 616d66085c..50ad4497d5 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -30,7 +30,7 @@ def mpt_tokenizer(): @fixture def build_tiny_mpt( - mpt_tokenizer: PreTrainedTokenizerBase + mpt_tokenizer: PreTrainedTokenizerBase, ) -> Callable[..., ComposerMPTCausalLM]: def build(**kwargs: Any) -> ComposerMPTCausalLM: @@ -51,7 +51,7 @@ def build(**kwargs: Any) -> ComposerMPTCausalLM: @fixture def build_tiny_hf_mpt( - mpt_tokenizer: PreTrainedTokenizerBase + mpt_tokenizer: PreTrainedTokenizerBase, ) -> Callable[..., ComposerHFCausalLM]: def build(**kwargs: Any) -> ComposerHFCausalLM: @@ -93,7 +93,7 @@ def tiny_gpt2_config_helper(): 'n_embd': 2, 'n_head': 2, 'n_layer': 2, - 'vocab_size': 50258 # 50257 + 1 for pad token + 'vocab_size': 50258, # 50257 + 1 for pad token } return transformers.AutoConfig.from_pretrained('gpt2', **tiny_overrides) @@ -130,7 +130,9 @@ def tiny_llama_tokenizer_helper(): transformers = pytest.importorskip('transformers') hf_tokenizer = transformers.AutoTokenizer.from_pretrained( - 'huggyllama/llama-7b', use_fast=False) + 'huggyllama/llama-7b', + use_fast=False, + ) return hf_tokenizer @@ -148,7 +150,8 @@ def tiny_opt_tokenizer_helper(): transformers = pytest.importorskip('transformers') hf_tokenizer = transformers.AutoTokenizer.from_pretrained( - 'facebook/opt-125m') + 'facebook/opt-125m', + ) hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) return hf_tokenizer @@ -181,10 +184,12 @@ def tiny_opt_config_helper(): 'n_embd': 2, 'n_head': 2, 'n_layer': 2, - 'vocab_size': 50272 + 'vocab_size': 50272, } - return transformers.AutoConfig.from_pretrained('facebook/opt-125m', - **tiny_overrides) + return transformers.AutoConfig.from_pretrained( + 'facebook/opt-125m', + **tiny_overrides, + ) @pytest.fixture diff --git a/tests/horrible_strings.py b/tests/horrible_strings.py index 31cd55cb9b..13c3978cb6 100644 --- a/tests/horrible_strings.py +++ b/tests/horrible_strings.py @@ -1,14 +1,16 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -# taken from https://github.com/explosion/spaCy/blob/8f0d6b0a8c42e4852bf6e24cdf629043f2f39361/spacy/tests/tokenizer/test_naughty_strings.py#L7 +# ruff: noqa: PLE2502 + +# Taken from https://github.com/explosion/spaCy/blob/8f0d6b0a8c42e4852bf6e24cdf629043f2f39361/spacy/tests/tokenizer/test_naughty_strings.py#L7 HORRIBLE_STRINGS = [ # ASCII punctuation r",./;'[]\-=", r'<>?:"{}|_+', r'!@#$%^&*()`~"', # Unicode additional control characters, byte order marks - r"­؀؁؂؃؄؅؜۝܏᠎​‌‍‎‏‪", + r"­؀؁؂؃؄؅؜۝܏᠎\u200b‌‍‎‏‪", r"￾", # Unicode Symbols r"Ω≈ç√∫˜µ≤≥÷", diff --git a/tests/models/hf/test_fsdp_weight_tying.py b/tests/models/hf/test_fsdp_weight_tying.py index 712e515653..4b76996ba1 100644 --- a/tests/models/hf/test_fsdp_weight_tying.py +++ b/tests/models/hf/test_fsdp_weight_tying.py @@ -14,23 +14,30 @@ @pytest.mark.world_size(2) @pytest.mark.gpu -@pytest.mark.parametrize('peft_config', [ - None, { - 'peft_type': 'LORA', - 'task_type': 'CAUSAL_LM', - 'lora_alpha': 32, - 'lora_dropout': 0.05, - 'r': 16, - 'target_modules': [ - 'q_proj', - 'k_proj', - 'v_proj', - ], - } -]) +@pytest.mark.parametrize( + 'peft_config', + [ + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'lora_dropout': 0.05, + 'r': 16, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }, + ], +) @pytest.mark.parametrize('init_device', ['cpu', 'mixed', 'meta']) -def test_fsdp_weight_tying(peft_config: Optional[dict], tmp_path: pathlib.Path, - init_device: str): +def test_fsdp_weight_tying( + peft_config: Optional[dict], + tmp_path: pathlib.Path, + init_device: str, +): model_cfg = { 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py index e79756aba3..191bce48f7 100644 --- a/tests/models/hf/test_hf_config.py +++ b/tests/models/hf/test_hf_config.py @@ -20,7 +20,8 @@ def test_remote_code_false_mpt( - conf_path: str = 'scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml'): + conf_path: str = 'scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml', +): with open(conf_path) as f: test_cfg = om.load(f) @@ -36,16 +37,18 @@ def test_remote_code_false_mpt( test_cfg.device = device test_cfg.precision = 'fp16' - tokenizer_cfg: Dict[str, - Any] = om.to_container(test_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = om.to_container( + test_cfg.tokenizer, + resolve=True, + ) # type: ignore tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) with pytest.raises( - ValueError, - match='trust_remote_code must be set to True for MPT models.'): + ValueError, + match='trust_remote_code must be set to True for MPT models.', + ): _ = build_composer_model( name=test_cfg.model.name, cfg=test_cfg.model, @@ -56,17 +59,19 @@ def test_remote_code_false_mpt( @pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_tie_weights(tie_word_embeddings: bool): # Test that the tie_weights function sets lm_head correctly - hf_config = MPTConfig(init_device='cpu', - d_model=128, - n_heads=4, - n_layers=2, - expansion_ratio=2, - max_seq_len=2048, - attn_config={ - 'attn_impl': 'torch', - }, - no_bias=True, - tie_word_embeddings=tie_word_embeddings) + hf_config = MPTConfig( + init_device='cpu', + d_model=128, + n_heads=4, + n_layers=2, + expansion_ratio=2, + max_seq_len=2048, + attn_config={ + 'attn_impl': 'torch', + }, + no_bias=True, + tie_word_embeddings=tie_word_embeddings, + ) mpt = MPTForCausalLM(hf_config) @@ -78,40 +83,49 @@ def test_tie_weights(tie_word_embeddings: bool): assert mpt.lm_head is not None -@pytest.mark.parametrize('model_cfg_overrides', [ - { - 'max_seq_len': 1024 - }, - { - 'attn_config': { - 'attn_impl': 'flash', - } - }, - { - 'init_config': { - 'emb_init_std': 5 - } - }, - { - 'max_seq_len': 1024, - 'attn_config': { - 'attn_impl': 'flash', +@pytest.mark.parametrize( + 'model_cfg_overrides', + [ + { + 'max_seq_len': 1024, + }, + { + 'attn_config': { + 'attn_impl': 'flash', + }, + }, + { + 'init_config': { + 'emb_init_std': 5, + }, }, - 'init_config': { - 'emb_init_std': 5 + { + 'max_seq_len': 1024, + 'attn_config': { + 'attn_impl': 'flash', + }, + 'init_config': { + 'emb_init_std': 5, + }, }, - }, - pytest.param({'msl': 1024}, - marks=pytest.mark.xfail(reason='"msl" is a ValueError', - strict=True)), - pytest.param({'attn_config': { - 'attn_iml': 'flash' - }}, - marks=pytest.mark.xfail(reason='"attn_impl" mispelled', - strict=True)), -]) -@patch('llmfoundry.models.layers.attention.is_flash_v2_installed', - new=Mock(return_value=True)) + pytest.param({'msl': 1024}, + marks=pytest.mark.xfail( + reason='"msl" is a ValueError', + strict=True, + )), + pytest.param({'attn_config': { + 'attn_iml': 'flash', + }}, + marks=pytest.mark.xfail( + reason='"attn_impl" mispelled', + strict=True, + )), + ], +) +@patch( + 'llmfoundry.models.layers.attention.is_flash_v2_installed', + new=Mock(return_value=True), +) def test_hf_config_override( model_cfg_overrides: Dict[str, Any], conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', @@ -132,9 +146,10 @@ def test_hf_config_override( test_cfg.precision = 'fp16' test_cfg.model.attn_config = {'attn_impl': 'torch', 'alibi': True} - tokenizer_cfg: Dict[str, - Any] = om.to_container(test_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = om.to_container( + test_cfg.tokenizer, + resolve=True, + ) # type: ignore tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) @@ -176,8 +191,10 @@ def test_hf_config_override( assert getattr(hf_model.config, k) == v -@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ, - reason='CI does not have access to llama2') +@pytest.mark.skipif( + 'HUGGING_FACE_HUB_TOKEN' not in os.environ, + reason='CI does not have access to llama2', +) def test_rope_scaling_override(): model_cfg = { 'name': 'hf_causal_lm', @@ -188,8 +205,8 @@ def test_rope_scaling_override(): 'intermediate_size': 64, 'rope_scaling': { 'type': 'dynamic', - 'factor': 0.5 - } + 'factor': 0.5, + }, }, 'use_auth_token': True, 'pretrained': False, @@ -207,8 +224,10 @@ def test_rope_scaling_override(): assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5} -@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ, - reason='CI does not have access to Dbrx') +@pytest.mark.skipif( + 'HUGGING_FACE_HUB_TOKEN' not in os.environ, + reason='CI does not have access to Dbrx', +) def test_nested_override(): model_cfg = { 'name': 'hf_causal_lm', @@ -216,7 +235,7 @@ def test_nested_override(): 'config_overrides': { 'ffn_config': { 'ffn_hidden_size': 500, - } + }, }, 'use_auth_token': True, 'pretrained': False, diff --git a/tests/models/hf/test_hf_mpt_gen.py b/tests/models/hf/test_hf_mpt_gen.py index 917e970852..ddd77cf8e0 100644 --- a/tests/models/hf/test_hf_mpt_gen.py +++ b/tests/models/hf/test_hf_mpt_gen.py @@ -24,19 +24,23 @@ def test_init_hfhub_mpt( pytest.skip(f'{attn_impl=} not implemented for {device=}.') composer_device = get_device(device) - model = build_tiny_hf_mpt(attn_config={ - 'attn_impl': attn_impl, - 'attn_uses_sequence_id': False, - }) + model = build_tiny_hf_mpt( + attn_config={ + 'attn_impl': attn_impl, + 'attn_uses_sequence_id': False, + }, + ) model = composer_device.module_to_device(model) model.eval() - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): _ = model.generate( composer_device.tensor_to_device( - mpt_tokenizer('hello', return_tensors='pt')['input_ids']), + mpt_tokenizer('hello', return_tensors='pt')['input_ids'], + ), max_new_tokens=2, ) @@ -45,7 +49,9 @@ def test_init_hfhub_mpt_cpu( build_tiny_hf_mpt: Callable[..., ComposerHFCausalLM], mpt_tokenizer: PreTrainedTokenizerBase, ): - test_init_hfhub_mpt(device='cpu', - attn_impl='torch', - build_tiny_hf_mpt=build_tiny_hf_mpt, - mpt_tokenizer=mpt_tokenizer) + test_init_hfhub_mpt( + device='cpu', + attn_impl='torch', + build_tiny_hf_mpt=build_tiny_hf_mpt, + mpt_tokenizer=mpt_tokenizer, + ) diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py index 7fe886ffe3..683a0ba0cd 100644 --- a/tests/models/hf/test_hf_peft_wrapping.py +++ b/tests/models/hf/test_hf_peft_wrapping.py @@ -17,11 +17,15 @@ def test_peft_wraps(): - mpt_cfg = transformers.AutoConfig.from_pretrained('mosaicml/mpt-7b', - n_layers=2, - trust_remote_code=True) - mpt = transformers.AutoModelForCausalLM.from_config(mpt_cfg, - trust_remote_code=True) + mpt_cfg = transformers.AutoConfig.from_pretrained( + 'mosaicml/mpt-7b', + n_layers=2, + trust_remote_code=True, + ) + mpt = transformers.AutoModelForCausalLM.from_config( + mpt_cfg, + trust_remote_code=True, + ) mpt = get_peft_model(mpt, LoraConfig()) prepare_hf_model_for_fsdp(mpt, 'cpu') @@ -35,22 +39,28 @@ def test_peft_wraps(): @pytest.mark.world_size(2) @pytest.mark.gpu -@pytest.mark.parametrize('peft_config', [{ - 'peft_type': 'LORA', - 'task_type': 'CAUSAL_LM', - 'lora_alpha': 32, - 'lora_dropout': 0.05, - 'r': 16, - 'target_modules': [ - 'q_proj', - 'k_proj', - 'v_proj', - ], -}]) +@pytest.mark.parametrize( + 'peft_config', + [{ + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'lora_dropout': 0.05, + 'r': 16, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }], +) @pytest.mark.parametrize('init_device', ['mixed']) @patch('torch.nn.init.kaiming_uniform_', lambda w, a: torch.nn.init.ones_(w)) -def test_lora_mixed_init(peft_config: Optional[dict], tmp_path: pathlib.Path, - init_device: str): +def test_lora_mixed_init( + peft_config: Optional[dict], + tmp_path: pathlib.Path, + init_device: str, +): model_cfg = { 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', diff --git a/tests/models/hf/test_hf_v_mpt.py b/tests/models/hf/test_hf_v_mpt.py index 82b64ce80c..042a18bf76 100644 --- a/tests/models/hf/test_hf_v_mpt.py +++ b/tests/models/hf/test_hf_v_mpt.py @@ -12,21 +12,32 @@ @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl,dropout,alibi,mask_val,no_attn_mask', [ - ('flash', 0.0, False, 1, False), - ('flash', 0.1, False, 1, False), - ('torch', 0.0, False, 1, False), - ('torch', 0.0, False, 0, False), - ('flash', 0.0, False, None, True), - ('torch', 0.0, False, None, True), -]) -def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, - mask_val: Optional[int], no_attn_mask: bool): +@pytest.mark.parametrize( + 'attn_impl,dropout,alibi,mask_val,no_attn_mask', + [ + ('flash', 0.0, False, 1, False), + ('flash', 0.1, False, 1, False), + ('torch', 0.0, False, 1, False), + ('torch', 0.0, False, 0, False), + ('flash', 0.0, False, None, True), + ('torch', 0.0, False, None, True), + ], +) +def test_compare_hf_v_mpt( + attn_impl: str, + dropout: float, + alibi: bool, + mask_val: Optional[int], + no_attn_mask: bool, +): warnings.filterwarnings( action='ignore', - message='Torchmetrics v0.9 introduced a new argument class property') - warnings.filterwarnings(action='ignore', - message='Using Fused Cross Entropy Loss.') + message='Torchmetrics v0.9 introduced a new argument class property', + ) + warnings.filterwarnings( + action='ignore', + message='Using Fused Cross Entropy Loss.', + ) conf_path = 'scripts/train/yamls/pretrain/mpt-125m.yaml' # set cfg path batch_size = 2 # set batch size @@ -43,10 +54,10 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, 'n_layer': 2, 'n_embd': 64, 'n_head': 8, - } + }, }, 'tokenizer': { - 'name': 'gpt2' + 'name': 'gpt2', }, }) @@ -124,22 +135,25 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, # generate random input branch batch = {} - batch['input_ids'] = torch.randint(low=0, - high=model_cfg.vocab_size, - size=(batch_size, - model_cfg.max_seq_len)).to(device) - batch['labels'] = torch.randint(low=0, - high=model_cfg.vocab_size, - size=(batch_size, - model_cfg.max_seq_len)).to(device) + batch['input_ids'] = torch.randint( + low=0, + high=model_cfg.vocab_size, + size=(batch_size, model_cfg.max_seq_len), + ).to(device) + batch['labels'] = torch.randint( + low=0, + high=model_cfg.vocab_size, + size=(batch_size, model_cfg.max_seq_len), + ).to(device) kpm = None if no_attn_mask: if 'attention_mask' in batch.keys(): _ = batch.pop('attention_mask') else: - batch['attention_mask'] = torch.ones(size=(batch_size, - model_cfg.max_seq_len), - dtype=torch.int64).to(device) + batch['attention_mask'] = torch.ones( + size=(batch_size, model_cfg.max_seq_len), + dtype=torch.int64, + ).to(device) # mask out some tokens assert mask_val is not None batch['attention_mask'][:, model_cfg.max_seq_len // 2:] = mask_val @@ -168,7 +182,10 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, hf_keys_ignore = ['.attn.masked_bias', '.attn.bias', 'lm_head'] # HF params which need to be transposed _transpose = [ - '.attn.c_attn.', '.attn.c_proj.', '.mlp.c_fc.', '.mlp.c_proj.' + '.attn.c_attn.', + '.attn.c_proj.', + '.mlp.c_fc.', + '.mlp.c_proj.', ] # HF keys which need to be replaced by the associated value hf_2_mosaic_key_mods = { @@ -215,7 +232,9 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, print(f'{hf_model_fwd = }\n{model_fwd = }') # given dropout seeded the same way, the mean of the outputs is extremely similar - assert hf_model_fwd.mean().allclose(model_fwd.mean(), - rtol=1e-04, - atol=1e-06) + assert hf_model_fwd.mean().allclose( + model_fwd.mean(), + rtol=1e-04, + atol=1e-06, + ) assert hf_model_fwd.allclose(model_fwd, rtol=1e-02, atol=1e-02) diff --git a/tests/models/inference_api_wrapper/test_fmapi.py b/tests/models/inference_api_wrapper/test_fmapi.py index bde2c90d36..72c41c2ebe 100644 --- a/tests/models/inference_api_wrapper/test_fmapi.py +++ b/tests/models/inference_api_wrapper/test_fmapi.py @@ -8,8 +8,10 @@ import transformers from omegaconf import DictConfig, ListConfig -from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper, - FMAPIChatAPIEvalWrapper) +from llmfoundry.models.inference_api_wrapper import ( + FMAPICasualLMEvalWrapper, + FMAPIChatAPIEvalWrapper, +) from llmfoundry.models.inference_api_wrapper.fmapi import FMAPIEvalInterface from llmfoundry.utils.builders import build_icl_evaluators @@ -29,9 +31,9 @@ def load_icl_config(): 'continuation_delimiter': '\nAnswer: ', 'has_categories': - True - }) - ]) + True, + }), + ]), }) @@ -94,32 +96,38 @@ def test_casual_fmapi_wrapper(tmp_path: str): _ = pytest.importorskip('openai') tokenizer = transformers.AutoTokenizer.from_pretrained( - 'mosaicml/mpt-7b-8k-instruct') - model = FMAPICasualLMEvalWrapper(om_model_config=DictConfig({ - 'local': True, - 'name': 'mosaicml/mpt-7b-8k-instruct' - }), - tokenizer=tokenizer) + 'mosaicml/mpt-7b-8k-instruct', + ) + model = FMAPICasualLMEvalWrapper( + om_model_config=DictConfig({ + 'local': True, + 'name': 'mosaicml/mpt-7b-8k-instruct', + }), + tokenizer=tokenizer, + ) with patch.object(model, 'client') as mock: mock.completions.create = mock_create task_cfg = load_icl_config() - evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, - tokenizer, - 1024, - 2, - destination_dir=str(tmp_path)) + evaluators, _ = build_icl_evaluators( + task_cfg.icl_tasks, + tokenizer, + 1024, + 2, + destination_dir=str(tmp_path), + ) batch = next(evaluators[0].dataloader.dataloader.__iter__()) result = model.eval_forward(batch) model.update_metric( batch, result, - metric=model.get_metrics() - ['InContextLearningLMAccuracy']) # pyright: ignore - acc = model.get_metrics( - )['InContextLearningLMAccuracy'].compute( # pyright: ignore + metric=model.get_metrics()['InContextLearningLMAccuracy'], ) # pyright: ignore + acc = model.get_metrics( + )['InContextLearningLMAccuracy' + ].compute( # pyright: ignore + ) # pyright: ignore assert acc == 0.5 @@ -128,31 +136,37 @@ def test_chat_fmapi_wrapper(tmp_path: str): _ = pytest.importorskip('openai') tokenizer = transformers.AutoTokenizer.from_pretrained( - 'mosaicml/mpt-7b-8k-instruct') - chatmodel = FMAPIChatAPIEvalWrapper(om_model_config=DictConfig({ - 'local': True, - 'name': 'mosaicml/mpt-7b-8k-instruct' - }), - tokenizer=tokenizer) + 'mosaicml/mpt-7b-8k-instruct', + ) + chatmodel = FMAPIChatAPIEvalWrapper( + om_model_config=DictConfig({ + 'local': True, + 'name': 'mosaicml/mpt-7b-8k-instruct', + }), + tokenizer=tokenizer, + ) with patch.object(chatmodel, 'client') as mock: mock.chat.completions.create.return_value = MockChatCompletion( - 'Treason!') + 'Treason!', + ) task_cfg = load_icl_config() - evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, - tokenizer, - 1024, - 2, - destination_dir=str(tmp_path)) + evaluators, _ = build_icl_evaluators( + task_cfg.icl_tasks, + tokenizer, + 1024, + 2, + destination_dir=str(tmp_path), + ) batch = next(evaluators[0].dataloader.dataloader.__iter__()) result = chatmodel.eval_forward(batch) chatmodel.update_metric( batch, result, - metric=chatmodel.get_metrics() - ['InContextLearningLMAccuracy']) # pyright: ignore + metric=chatmodel.get_metrics()['InContextLearningLMAccuracy'], + ) # pyright: ignore acc = chatmodel.get_metrics( )['InContextLearningLMAccuracy'].compute( # pyright: ignore ) diff --git a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py index 7ecb61aa43..acc3cb9622 100644 --- a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py +++ b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py @@ -8,8 +8,10 @@ import pytest from omegaconf import DictConfig, ListConfig -from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper, - OpenAIChatAPIEvalWrapper) +from llmfoundry.models.inference_api_wrapper import ( + OpenAICausalLMEvalWrapper, + OpenAIChatAPIEvalWrapper, +) from llmfoundry.tokenizers import TiktokenTokenizerWrapper from llmfoundry.utils.builders import build_icl_evaluators @@ -35,9 +37,9 @@ def load_icl_config(): 'continuation_delimiter': '\nAnswer: ', 'has_categories': - True - }) - ]) + True, + }), + ]), }) @@ -97,30 +99,37 @@ def test_openai_api_eval_wrapper(tmp_path: str, openai_api_key_env_var: str): _ = pytest.importorskip('openai') model_name = 'davinci' - tokenizer = TiktokenTokenizerWrapper(model_name=model_name, - pad_token='<|endoftext|>') - model = OpenAICausalLMEvalWrapper(om_model_config=DictConfig( - {'version': model_name}), - tokenizer=tokenizer) + tokenizer = TiktokenTokenizerWrapper( + model_name=model_name, + pad_token='<|endoftext|>', + ) + model = OpenAICausalLMEvalWrapper( + om_model_config=DictConfig({'version': model_name}), + tokenizer=tokenizer, + ) with patch.object(model, 'client') as mock: mock.completions.create = mock_create task_cfg = load_icl_config() - evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, - tokenizer, - 1024, - 2, - destination_dir=str(tmp_path)) + evaluators, _ = build_icl_evaluators( + task_cfg.icl_tasks, + tokenizer, + 1024, + 2, + destination_dir=str(tmp_path), + ) batch = next(evaluators[0].dataloader.dataloader.__iter__()) result = model.eval_forward(batch) - model.update_metric(batch, - result, - metric=model.get_metrics() - ['InContextLearningLMAccuracy']) # pyright: ignore - acc = model.get_metrics( - )['InContextLearningLMAccuracy'].compute( # pyright: ignore + model.update_metric( + batch, + result, + metric=model.get_metrics()['InContextLearningLMAccuracy'], ) # pyright: ignore + acc = model.get_metrics( + )['InContextLearningLMAccuracy' + ].compute( # pyright: ignore + ) # pyright: ignore assert acc == 0.5 @@ -128,29 +137,35 @@ def test_chat_api_eval_wrapper(tmp_path: str, openai_api_key_env_var: str): _ = pytest.importorskip('openai') model_name = 'gpt-3.5-turbo' - tokenizer = TiktokenTokenizerWrapper(model_name=model_name, - pad_token='<|endoftext|>') - chatmodel = OpenAIChatAPIEvalWrapper(om_model_config=DictConfig( - {'version': model_name}), - tokenizer=tokenizer) + tokenizer = TiktokenTokenizerWrapper( + model_name=model_name, + pad_token='<|endoftext|>', + ) + chatmodel = OpenAIChatAPIEvalWrapper( + om_model_config=DictConfig({'version': model_name}), + tokenizer=tokenizer, + ) with patch.object(chatmodel, 'client') as mock: mock.chat.completions.create.return_value = MockChatCompletion( - 'Treason!') + 'Treason!', + ) task_cfg = load_icl_config() - evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, - tokenizer, - 1024, - 2, - destination_dir=str(tmp_path)) + evaluators, _ = build_icl_evaluators( + task_cfg.icl_tasks, + tokenizer, + 1024, + 2, + destination_dir=str(tmp_path), + ) batch = next(evaluators[0].dataloader.dataloader.__iter__()) result = chatmodel.eval_forward(batch) chatmodel.update_metric( batch, result, - metric=chatmodel.get_metrics() - ['InContextLearningLMAccuracy']) # pyright: ignore + metric=chatmodel.get_metrics()['InContextLearningLMAccuracy'], + ) # pyright: ignore acc = chatmodel.get_metrics( )['InContextLearningLMAccuracy'].compute( # pyright: ignore ) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index c8e7ec3e67..3a561b3d3c 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -13,8 +13,10 @@ import torch.optim as optim from torch.distributed._tensor import DTensor, Placement, Replicate, Shard from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.distributed.checkpoint.state_dict import (StateDictOptions, - get_model_state_dict) +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, +) from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform from torch.nn.parallel import DistributedDataParallel as DDP @@ -37,13 +39,13 @@ def _get_all_inputs( world_size: int = dist.get_world_size() rank: int = dist.get_rank() device: torch.device = torch.device(f'cuda:{rank}') - all_inputs = [] - for _ in range(world_size): - all_inputs.append(torch.rand( + all_inputs = [ + torch.rand( input_shape, device=device, dtype=dtype, - )) + ) for _ in range(world_size) + ] return all_inputs @@ -55,16 +57,22 @@ def _get_torch_dtype(fp16: bool, bf16: bool) -> Optional[torch.dtype]: return None -@pytest.mark.skipif(not is_megablocks_imported, - reason='This test needs megablocks module') +@pytest.mark.skipif( + not is_megablocks_imported, + reason='This test needs megablocks module', +) @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.parametrize('moe_num_experts', [8]) @pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) @pytest.mark.parametrize('moe_world_size', [1, 2]) @pytest.mark.parametrize('two_d_input', [True, False]) -def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, - two_d_input: bool): +def test_dmoe( + moe_num_experts: int, + mlp_type: str, + moe_world_size: int, + two_d_input: bool, +): # Generate inputs rank = dist.get_rank() batch_size = 2 @@ -120,11 +128,10 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, mesh_dim_names=('weight_parallel', 'expert_parallel'), ) expert_parallel_group = device_mesh['expert_parallel'].get_group(0) - extra_args.update( - { - 'moe_expert_model_parallelism': True, - 'expert_parallel_group': expert_parallel_group, - },) + extra_args.update({ + 'moe_expert_model_parallelism': True, + 'expert_parallel_group': expert_parallel_group, + },) mp_dmoe_args.update(extra_args) args = megablocks.layers.arguments.Arguments(**mp_dmoe_args,) mb_dmoe = megablocks.layers.dmoe.dMoE(args).to(device) @@ -152,9 +159,10 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, mb_dmoe.experts = DDP(mb_dmoe.experts, process_group=dp_pg) # Copy mb_dmoe's parameters to torch_dmoe - mb_dmoe_state_dict = get_model_state_dict(mb_dmoe, - options=StateDictOptions( - full_state_dict=True,)) + mb_dmoe_state_dict = get_model_state_dict( + mb_dmoe, + options=StateDictOptions(full_state_dict=True,), + ) for key, t in mb_dmoe_state_dict.items(): if key in tp_names: dtensor_full = DTensor.from_local( @@ -166,9 +174,10 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, mb_dmoe_state_dict[key] = dtensor_full else: mb_dmoe.experts = DDP(mb_dmoe.experts, device_ids=[rank]) - mb_dmoe_state_dict = get_model_state_dict(mb_dmoe, - options=StateDictOptions( - full_state_dict=True,)) + mb_dmoe_state_dict = get_model_state_dict( + mb_dmoe, + options=StateDictOptions(full_state_dict=True,), + ) mb_dmoe_optimizer = optim.SGD(mb_dmoe.parameters(), lr=0.1) # Load mb_dmoe state dict to torch dmoe @@ -188,45 +197,51 @@ def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int, torch.testing.assert_close(torch_y, mb_y) -@pytest.mark.skipif(not is_megablocks_imported, - reason='This test needs megablocks module') +@pytest.mark.skipif( + not is_megablocks_imported, + reason='This test needs megablocks module', +) @pytest.mark.gpu @pytest.mark.parametrize('seqlen', [512]) @pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) @pytest.mark.parametrize('precision', ['bf16', 'fp32']) def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): - mb_dmoe_config = MPTConfig(d_model=1024, - n_heads=32, - n_layers=1, - learned_pos_emb=False, - max_seq_len=2048, - vocab_size=100, - no_bias=True, - fuse_norm_attn_norm=True, - tie_word_embeddings=False, - attn_config=dict( - attn_type='grouped_query_attention', - attn_impl='torch', - attn_pdrop=0.0, - clip_qkv=8.0, - kv_n_heads=8, - rope=True, - rope_theta=10000.0, - ), - ffn_config=dict( - ffn_type='mb_dmoe', - fc_type='torch', - mlp_type=mlp_type, - moe_world_size=1, - ffn_act_fn={'name': 'silu'}, - ffn_hidden_size=1792, - moe_num_experts=16, - moe_top_k=4, - moe_jitter_eps=0.0, - moe_loss_weight=0.05, - moe_normalize_expert_weights=1.0, - uniform_expert_assignment=False, - )) + mb_dmoe_config = MPTConfig( + d_model=1024, + n_heads=32, + n_layers=1, + learned_pos_emb=False, + max_seq_len=2048, + vocab_size=100, + no_bias=True, + fuse_norm_attn_norm=True, + tie_word_embeddings=False, + attn_config={ + 'attn_type': 'grouped_query_attention', + 'attn_impl': 'torch', + 'attn_pdrop': 0.0, + 'clip_qkv': 8.0, + 'kv_n_heads': 8, + 'rope': True, + 'rope_theta': 10000.0, + }, + ffn_config={ + 'ffn_type': 'mb_dmoe', + 'fc_type': 'torch', + 'mlp_type': mlp_type, + 'moe_world_size': 1, + 'ffn_act_fn': { + 'name': 'silu', + }, + 'ffn_hidden_size': 1792, + 'moe_num_experts': 16, + 'moe_top_k': 4, + 'moe_jitter_eps': 0.0, + 'moe_loss_weight': 0.05, + 'moe_normalize_expert_weights': 1.0, + 'uniform_expert_assignment': False, + }, + ) device = 'cuda:0' if precision == 'fp32': dtype = torch.float32 @@ -244,10 +259,12 @@ def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str): del torch_dmoe_config.ffn_config['moe_loss_weight'] del torch_dmoe_config.ffn_config['return_bias'] - mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device, - dtype=dtype) - torch_dmoe_model = MPTForCausalLM(torch_dmoe_config).to(device=device, - dtype=dtype) + mb_dmoe_model = MPTForCausalLM( + mb_dmoe_config, + ).to(device=device, dtype=dtype) + torch_dmoe_model = MPTForCausalLM( + torch_dmoe_config, + ).to(device=device, dtype=dtype) # set same state dicts torch_dmoe_model.load_state_dict(mb_dmoe_model.state_dict()) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index a3b17c36df..dcce0fe118 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -7,15 +7,22 @@ import torch from llmfoundry.models.layers.attention import ( - attn_bias_shape, build_attn_bias, check_alibi_support, flash_attn_fn, - gen_slopes, is_flash_v2_installed, scaled_multihead_dot_product_attention) + attn_bias_shape, + build_attn_bias, + check_alibi_support, + flash_attn_fn, + gen_slopes, + is_flash_v2_installed, + scaled_multihead_dot_product_attention, +) from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info @pytest.mark.gpu @pytest.mark.skipif( not is_flash_v2_installed(), - reason='GQA natively only supported by Flash Attention after v2.') + reason='GQA natively only supported by Flash Attention after v2.', +) @pytest.mark.parametrize('kv_n_heads', [1, 4, 8]) def test_gqa_kv_repetition(kv_n_heads: int): # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same @@ -48,8 +55,15 @@ def test_gqa_kv_repetition(kv_n_heads: int): training=False, needs_weights=False, flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, seqlen_1, 0, query_1.device, None, None), - should_repeat_kv_for_gqa=True) + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + should_repeat_kv_for_gqa=True, + ) output_1.sum().backward() @@ -75,8 +89,15 @@ def test_gqa_kv_repetition(kv_n_heads: int): training=False, needs_weights=False, flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, seqlen_1, 0, query_2.device, None, None), - should_repeat_kv_for_gqa=False) + bsz, + seqlen_1, + 0, + query_2.device, + None, + None, + ), + should_repeat_kv_for_gqa=False, + ) output_2.sum().backward() assert torch.allclose(output_1, output_2) @@ -89,7 +110,7 @@ def test_gqa_kv_repetition(kv_n_heads: int): @pytest.mark.skipif( not is_flash_v2_installed(v2_version='v2.1.2'), reason= - 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', ) def test_seq_id_masking_FA_v2(): # Test that flash attention v2 with sequence id masking works correctly. @@ -108,14 +129,22 @@ def test_seq_id_masking_FA_v2(): value_1.requires_grad = True seq_ranges = [ - (0, 3), (3, 5), (5, 6) + (0, 3), + (3, 5), + (5, 6), ] # Each batch has 3 sequences of length 3, 2, and 1 respectively. attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda() flash_attn_padding_info_1 = gen_flash_attn_padding_info( - bsz, seqlen_1, 0, query_1.device, attention_mask_in_length_1, None) + bsz, + seqlen_1, + 0, + query_1.device, + attention_mask_in_length_1, + None, + ) output_1, _, _ = flash_attn_fn( query=query_1, @@ -131,7 +160,8 @@ def test_seq_id_masking_FA_v2(): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=flash_attn_padding_info_1) + flash_attn_padding_info=flash_attn_padding_info_1, + ) output_1.sum().backward() @@ -144,7 +174,13 @@ def test_seq_id_masking_FA_v2(): value_2.requires_grad = True flash_attn_padding_info_2 = gen_flash_attn_padding_info( - bsz, seq_range[1] - seq_range[0], 0, query_2.device, None, None) + bsz, + seq_range[1] - seq_range[0], + 0, + query_2.device, + None, + None, + ) output_2, _, _ = flash_attn_fn( query=query_2, @@ -160,27 +196,34 @@ def test_seq_id_masking_FA_v2(): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=flash_attn_padding_info_2) + flash_attn_padding_info=flash_attn_padding_info_2, + ) output_2.sum().backward() - assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], - output_2) + assert torch.allclose( + output_1[:, seq_range[0]:seq_range[1], :], + output_2, + ) assert torch.allclose( query_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore - query_2.grad) # type: ignore + query_2.grad, # type: ignore + ) assert torch.allclose( key_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore - key_2.grad) # type: ignore + key_2.grad, # type: ignore + ) assert torch.allclose( value_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore - value_2.grad) # type: ignore + value_2.grad, # type: ignore + ) @pytest.mark.gpu @pytest.mark.skipif( not is_flash_v2_installed(v2_version='v2.3.0'), reason= - 'Sliding window attention only supported by Flash Attention after v2.3.0.') + 'Sliding window attention only supported by Flash Attention after v2.3.0.', +) @pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) def test_sliding_window(sliding_window_size: int): # Test that sliding window attention works as expected. @@ -191,14 +234,14 @@ def test_sliding_window(sliding_window_size: int): seqlen_1 = 8 bsz = 2 - query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, - device=device) + query_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) query_1.requires_grad = True - key_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, - device=device) + key_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) key_1.requires_grad = True - value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, - device=device) + value_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) value_1.requires_grad = True output_1, _, _ = flash_attn_fn( @@ -216,9 +259,16 @@ def test_sliding_window(sliding_window_size: int): training=False, needs_weights=False, flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, seqlen_1, 0, query_1.device, None, None), + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), should_repeat_kv_for_gqa=True, - sliding_window_size=sliding_window_size) + sliding_window_size=sliding_window_size, + ) output_1.sum().backward() @@ -229,12 +279,13 @@ def test_sliding_window(sliding_window_size: int): value_2 = value_1.detach().clone() value_2.requires_grad = True - attn_bias_2 = torch.zeros(1, 1, seqlen_1, seqlen_1).to(dtype=dtype, - device=device) + attn_bias_2 = torch.zeros(1, 1, seqlen_1, + seqlen_1).to(dtype=dtype, device=device) window_mask_2 = torch.tril( - torch.ones(seqlen_1, seqlen_1), diagonal=-(sliding_window_size + 1)).to( - dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min + torch.ones(seqlen_1, seqlen_1), + diagonal=-(sliding_window_size + 1), + ).to(dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min attn_bias_2 = attn_bias_2 + window_mask_2 output_2, _, _ = scaled_multihead_dot_product_attention( query=query_2, @@ -268,7 +319,8 @@ def test_sliding_window(sliding_window_size: int): @pytest.mark.gpu @pytest.mark.skipif( not check_alibi_support('flash'), - reason='ALiBi only supported by Flash Attention after v2.4.2.') + reason='ALiBi only supported by Flash Attention after v2.4.2.', +) @pytest.mark.parametrize('n_heads', [1, 6, 8]) def test_alibi_bias(n_heads: int): # Test that sliding window attention works as expected. @@ -278,19 +330,21 @@ def test_alibi_bias(n_heads: int): seqlen_1 = 8 bsz = 2 - query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, - device=device) + query_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) query_1.requires_grad = True - key_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, - device=device) + key_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) key_1.requires_grad = True - value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, - device=device) + value_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) value_1.requires_grad = True - alibi_slopes_1 = gen_slopes(n_heads=n_heads, - alibi_bias_max=8, - device=torch.device(device), - return_1d=True) + alibi_slopes_1 = gen_slopes( + n_heads=n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) output_1, _, _ = flash_attn_fn( query=query_1, key=key_1, @@ -306,9 +360,16 @@ def test_alibi_bias(n_heads: int): training=False, needs_weights=False, flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, seqlen_1, 0, query_1.device, None, None), + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), should_repeat_kv_for_gqa=True, - alibi_slopes=alibi_slopes_1) + alibi_slopes=alibi_slopes_1, + ) output_1.sum().backward() @@ -321,12 +382,14 @@ def test_alibi_bias(n_heads: int): def gen_bias(): causal = True - bs = attn_bias_shape('torch', - n_heads, - seqlen_1, - True, - use_sequence_id=False, - causal=causal) + bs = attn_bias_shape( + 'torch', + n_heads, + seqlen_1, + True, + use_sequence_id=False, + causal=causal, + ) attn_bias = torch.zeros(*bs, device=device) attn_bias = build_attn_bias( diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index f212665c93..669a6a93a1 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -6,19 +6,26 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers import attention -from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes, - is_flash_v2_installed) +from llmfoundry.models.layers.attention import ( + check_alibi_support, + gen_slopes, + is_flash_v2_installed, +) from llmfoundry.models.layers.layer_builders import build_attention_layer -from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, - gen_attention_mask_in_length, - gen_flash_attn_padding_info, - gen_rotary_embedding) - - -def allclose_helper(t0: torch.Tensor, - t1: torch.Tensor, - rtol: float = 1e-2, - atol: float = 1e-2): +from llmfoundry.models.mpt.modeling_mpt import ( + apply_sequence_id, + gen_attention_mask_in_length, + gen_flash_attn_padding_info, + gen_rotary_embedding, +) + + +def allclose_helper( + t0: torch.Tensor, + t1: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, +): return torch.allclose(t0, t1, rtol=rtol, atol=atol) @@ -27,52 +34,61 @@ def allclose_helper(t0: torch.Tensor, ('flash', 'torch'), ]) @pytest.mark.parametrize('clip_qkv', [True, False]) -@pytest.mark.parametrize('qk_ln, qk_gn', [ - (True, False), - (False, True), - (False, False), -]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) +@pytest.mark.parametrize( + 'qk_ln, qk_gn', + [ + (True, False), + (False, True), + (False, False), + ], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) @pytest.mark.parametrize( 'attn_type', - ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) + ['multihead_attention', 'multiquery_attention', 'grouped_query_attention'], +) @pytest.mark.parametrize('attn_uses_sequence_id', [True, False]) @pytest.mark.parametrize('pad_attention_mask', [True, False]) -def test_attn_impl(attn_impl_0: str, - attn_impl_1: str, - clip_qkv: bool, - qk_ln: bool, - qk_gn: bool, - pos_emb_config: dict, - attn_type: str, - attn_uses_sequence_id: bool, - pad_attention_mask: bool, - device: str = 'cuda'): +def test_attn_impl( + attn_impl_0: str, + attn_impl_1: str, + clip_qkv: bool, + qk_ln: bool, + qk_gn: bool, + pos_emb_config: dict, + attn_type: str, + attn_uses_sequence_id: bool, + pad_attention_mask: bool, + device: str = 'cuda', +): """Compare all attn impl with each other. Includes testing with and without attn_clip_qkv, attn_qk_ln, attn_qk_gn, @@ -80,18 +96,19 @@ def test_attn_impl(attn_impl_0: str, """ alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] - if alibi and not (check_alibi_support(attn_impl_0) and - check_alibi_support(attn_impl_1)): + if alibi and not ( + check_alibi_support(attn_impl_0) and check_alibi_support(attn_impl_1) + ): pytest.skip('flash attention below v2.4.2 does not support alibi.') if rope and (pos_emb_config['rope_impl'] == 'dail') and (not is_flash_v2_installed()): pytest.skip('dail implementation of rope requires flash attention 2.') if attn_uses_sequence_id and ( - attn_impl_0 == 'flash' or attn_impl_1 - == 'flash') and (not is_flash_v2_installed(v2_version='v2.1.2')): + attn_impl_0 == 'flash' or attn_impl_1 == 'flash' + ) and (not is_flash_v2_installed(v2_version='v2.1.2')): pytest.skip( - 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', ) if not (alibi or rope) and attn_uses_sequence_id: @@ -116,9 +133,10 @@ def test_attn_impl(attn_impl_0: str, if attn_uses_sequence_id: assert n == 2 assert s >= 4 - sequence_id = torch.LongTensor([[0] * 2 + [1] * (s - 2), - [0] * 4 + [1] * (s - 4) - ]).to(device=device) + sequence_id = torch.LongTensor([ + [0] * 2 + [1] * (s - 2), + [0] * 4 + [1] * (s - 4), + ]).to(device=device) cfg.attn_impl = attn_impl_0 attn0 = build_attention_layer( @@ -141,18 +159,21 @@ def test_attn_impl(attn_impl_0: str, attention_mask[:, -s // 3:] = 0 if sequence_id is not None: sequence_id = sequence_id.masked_fill( - ~attention_mask, -1 + ~attention_mask, + -1, ) # Similar to how we set sequence id for padded tokens: https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249 def gen_bias(attn_impl: str): causal = True attn_bias = None - bs = attention.attn_bias_shape(attn_impl, - cfg.n_heads, - s, - alibi, - use_sequence_id=attn_uses_sequence_id, - causal=causal) + bs = attention.attn_bias_shape( + attn_impl, + cfg.n_heads, + s, + alibi, + use_sequence_id=attn_uses_sequence_id, + causal=causal, + ) if bs is not None: attn_bias = torch.zeros(*bs, device=device) attn_bias = attention.build_attn_bias( @@ -169,7 +190,8 @@ def gen_bias(attn_impl: str): attn_bias = apply_sequence_id( attn_bias, sequence_id, # type: ignore - s) + s, + ) return attn_bias @@ -178,26 +200,38 @@ def gen_bias(attn_impl: str): S=s, attn_uses_sequence_id=attn_uses_sequence_id, attn_impl=attn_impl_0, - attention_mask=attention_mask) + attention_mask=attention_mask, + ) flash_attn_padding_info_0 = {} if attn_impl_0 == 'flash': flash_attn_padding_info_0 = gen_flash_attn_padding_info( - n, s, 0, torch.device(device), attention_mask_in_length_0, - attention_mask) + n, + s, + 0, + torch.device(device), + attention_mask_in_length_0, + attention_mask, + ) attention_mask_in_length_1 = gen_attention_mask_in_length( sequence_id=sequence_id, S=s, attn_uses_sequence_id=attn_uses_sequence_id, attn_impl=attn_impl_1, - attention_mask=attention_mask) + attention_mask=attention_mask, + ) flash_attn_padding_info_1 = {} if attn_impl_1 == 'flash': flash_attn_padding_info_1 = gen_flash_attn_padding_info( - n, s, 0, torch.device(device), attention_mask_in_length_1, - attention_mask) + n, + s, + 0, + torch.device(device), + attention_mask_in_length_1, + attention_mask, + ) x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() @@ -208,10 +242,12 @@ def gen_bias(attn_impl: str): attn_bias_0 = gen_bias(attn_impl_0) alibi_slopes_0 = None if alibi and attn_impl_0 == 'flash': - alibi_slopes_0 = gen_slopes(n_heads=cfg.n_heads, - alibi_bias_max=8, - device=torch.device(device), - return_1d=True) + alibi_slopes_0 = gen_slopes( + n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( @@ -220,7 +256,8 @@ def gen_bias(attn_impl: str): rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), - max_seq_len=s).to(device) + max_seq_len=s, + ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens pos = torch.clamp( @@ -238,29 +275,35 @@ def gen_bias(attn_impl: str): s, } - y0, _, _ = attn0(x0, - past_key_value=None, - attn_bias=attn_bias_0, - attention_mask=attention_mask, - rotary_emb_w_meta_info=rotary_emb_w_meta_info, - is_causal=True, - flash_attn_padding_info=flash_attn_padding_info_0, - alibi_slopes=alibi_slopes_0) + y0, _, _ = attn0( + x0, + past_key_value=None, + attn_bias=attn_bias_0, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info_0, + alibi_slopes=alibi_slopes_0, + ) attn_bias_1 = gen_bias(attn_impl_1) alibi_slopes_1 = None if alibi and attn_impl_1 == 'flash': - alibi_slopes_1 = gen_slopes(n_heads=cfg.n_heads, - alibi_bias_max=8, - device=torch.device(device), - return_1d=True) - y1, _, _ = attn1(x1, - past_key_value=None, - attn_bias=attn_bias_1, - attention_mask=attention_mask, - rotary_emb_w_meta_info=rotary_emb_w_meta_info, - is_causal=True, - flash_attn_padding_info=flash_attn_padding_info_1, - alibi_slopes=alibi_slopes_1) + alibi_slopes_1 = gen_slopes( + n_heads=cfg.n_heads, + alibi_bias_max=8, + device=torch.device(device), + return_1d=True, + ) + y1, _, _ = attn1( + x1, + past_key_value=None, + attn_bias=attn_bias_1, + attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info_1, + alibi_slopes=alibi_slopes_1, + ) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -272,19 +315,21 @@ def gen_bias(attn_impl: str): assert allclose_helper(y0, y1) - torch_name_param_map = {n: p for n, p in attn1.named_parameters()} + torch_name_param_map = dict(attn1.named_parameters()) for n, p in attn0.named_parameters(): tp = torch_name_param_map[n] assert p.grad is not None assert tp.grad is not None assert allclose_helper(p, tp) - using_hf_rope = pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'hf' + using_hf_rope = pos_emb_config['rope'] and pos_emb_config['rope_impl' + ] == 'hf' # special case that (likely) fails due to numerics - if (clip_qkv and (qk_ln or qk_gn) and using_hf_rope and - attn_type == 'grouped_query_attention'): + if ( + clip_qkv and (qk_ln or qk_gn) and using_hf_rope and + attn_type == 'grouped_query_attention' + ): assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2) else: assert allclose_helper(p.grad, tp.grad) @@ -345,19 +390,29 @@ def gen_tca_mask(): flash_attn_padding_info = None if attn_impl == 'flash': flash_attn_padding_info = gen_flash_attn_padding_info( - n, s, 0, torch.device(device), None, attention_mask) - y0, _, _ = mmhsa(x0, - past_key_value=None, - attn_bias=None, - attention_mask=attention_mask, - is_causal=True, - flash_attn_padding_info=flash_attn_padding_info) - y1, _ = tmhsa(x1, - x1, - x1, - attn_mask=gen_tca_mask(), - key_padding_mask=~attention_mask, - need_weights=True) + n, + s, + 0, + torch.device(device), + None, + attention_mask, + ) + y0, _, _ = mmhsa( + x0, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + ) + y1, _ = tmhsa( + x1, + x1, + x1, + attn_mask=gen_tca_mask(), + key_padding_mask=~attention_mask, + need_weights=True, + ) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -383,8 +438,10 @@ def gen_tca_mask(): assert allclose_helper(y0, y1) assert allclose_helper(tmhsa.out_proj.bias.grad, mmhsa.out_proj.bias.grad) - assert allclose_helper(tmhsa.out_proj.weight.grad, - mmhsa.out_proj.weight.grad) + assert allclose_helper( + tmhsa.out_proj.weight.grad, + mmhsa.out_proj.weight.grad, + ) assert allclose_helper(tmhsa.in_proj_bias.grad, mmhsa.Wqkv.bias.grad) assert allclose_helper(tmhsa.in_proj_weight.grad, mmhsa.Wqkv.weight.grad) @@ -395,10 +452,12 @@ def gen_tca_mask(): @pytest.mark.parametrize('attn_impl', ['flash', 'torch']) @pytest.mark.parametrize('n_heads', [16, 8]) @pytest.mark.parametrize('kv_n_heads', [4, 2, 1]) -def test_grouped_attention_heads(attn_impl: str, - n_heads: int, - kv_n_heads: int, - device: str = 'cuda'): +def test_grouped_attention_heads( + attn_impl: str, + n_heads: int, + kv_n_heads: int, + device: str = 'cuda', +): """Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads.""" from llmfoundry.models.layers import attention @@ -409,7 +468,7 @@ def test_grouped_attention_heads(attn_impl: str, 'attn_pdrop': 0, 'clip_qkv': False, 'qk_ln': False, - 'kv_n_heads': kv_n_heads + 'kv_n_heads': kv_n_heads, }) n, s, f = 2, 4, cfg.d_model @@ -424,13 +483,21 @@ def test_grouped_attention_heads(attn_impl: str, flash_attn_padding_info = None if attn_impl == 'flash': flash_attn_padding_info = gen_flash_attn_padding_info( - n, s, 0, torch.device(device), None, attention_mask) - y0, _, _ = mmhsa(x0, - past_key_value=None, - attn_bias=None, - attention_mask=attention_mask, - is_causal=True, - flash_attn_padding_info=flash_attn_padding_info) + n, + s, + 0, + torch.device(device), + None, + attention_mask, + ) + y0, _, _ = mmhsa( + x0, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info, + ) y0 *= attention_mask.unsqueeze(-1) loss0 = y0.sum() @@ -449,7 +516,7 @@ def test_grouped_query_invalid_heads(): 'attn_pdrop': 0, 'clip_qkv': False, 'qk_ln': False, - 'kv_n_heads': 3 + 'kv_n_heads': 3, }) expected_error = 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads' diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py index 08891d5199..88113cf55b 100644 --- a/tests/models/layers/test_huggingface_flash.py +++ b/tests/models/layers/test_huggingface_flash.py @@ -33,7 +33,9 @@ def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str): tokenizer_name = 'codellama/CodeLlama-7b-hf' from transformers.models.llama.modeling_llama import ( - LlamaAttention, LlamaFlashAttention2) + LlamaAttention, + LlamaFlashAttention2, + ) flash_attn_class = LlamaFlashAttention2 if use_flash_attention_2 else LlamaAttention attention_layers_attr = 'model.model.layers' attention_attr = 'self_attn' @@ -52,7 +54,8 @@ def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str): tokenizer.pad_token = tokenizer.eos_token error_context = pytest.raises( - ValueError, match='use_flash_attention_2 is set to True' + ValueError, + match='use_flash_attention_2 is set to True', ) if not is_flash_v2_installed( ) and use_flash_attention_2 else contextlib.nullcontext() @@ -65,19 +68,24 @@ def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str): # check that it actually used flash attention 2 assert model.model.config._attn_implementation == ( - 'flash_attention_2' if use_flash_attention_2 else 'eager') + 'flash_attention_2' if use_flash_attention_2 else 'eager' + ) attention_layer = rgetattr( - rgetattr(model, attention_layers_attr)[0], attention_attr) + rgetattr(model, attention_layers_attr)[0], + attention_attr, + ) assert isinstance(attention_layer, flash_attn_class) # Skip attempting to run forward/backward when some devices have meta params # because we are not instantiating a full Trainer here, which contains the logic # to move params off of meta device. if init_device == 'cpu': - tokenized_input = tokenizer( - ['Hello world blah blah', 'Goodbye world'], - return_tensors='pt', - padding=True) + tokenized_input = tokenizer([ + 'Hello world blah blah', + 'Goodbye world', + ], + return_tensors='pt', + padding=True) tokenized_input['labels'] = tokenized_input['input_ids'].clone() tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()} diff --git a/tests/models/test_fsdp_act_checkpoint.py b/tests/models/test_fsdp_act_checkpoint.py index 97063b25c4..ab5f2705b4 100644 --- a/tests/models/test_fsdp_act_checkpoint.py +++ b/tests/models/test_fsdp_act_checkpoint.py @@ -16,15 +16,22 @@ @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('activation_checkpointing', [True, False]) -@pytest.mark.parametrize('activation_checkpointing_target', [ - 'grouped_query_attention', [], ['grouped_query_attention'], { - 'mptblock': [1], - 'grouped_query_attention': 'first-1, last-1' - } -]) -def test_fsdp_act_checkpoint(activation_checkpointing: bool, - activation_checkpointing_target: Union[list, str, - dict]): +@pytest.mark.parametrize( + 'activation_checkpointing_target', + [ + 'grouped_query_attention', + [], + ['grouped_query_attention'], + { + 'mptblock': [1], + 'grouped_query_attention': 'first-1, last-1', + }, + ], +) +def test_fsdp_act_checkpoint( + activation_checkpointing: bool, + activation_checkpointing_target: Union[list, str, dict], +): device = get_device('gpu') model_cfg = { 'name': 'mpt_causal_lm', @@ -38,7 +45,7 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, 'attn_type': 'grouped_query_attention', 'kv_n_heads': 2, }, - 'activation_checkpointing_target': activation_checkpointing_target + 'activation_checkpointing_target': activation_checkpointing_target, } model_cfg = om.create(model_cfg) @@ -61,31 +68,41 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, if not activation_checkpointing: assert not isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. - blocks[0], CheckpointWrapper) + blocks[0], + CheckpointWrapper, + ) elif (not activation_checkpointing_target): module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[ 0]._fsdp_wrapped_module assert isinstance(module, CheckpointWrapper) elif activation_checkpointing_target == [ - 'grouped_query_attention' + 'grouped_query_attention', ] or activation_checkpointing_target == 'grouped_query_attention': assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. - blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + blocks[0]._fsdp_wrapped_module.attn, + CheckpointWrapper, + ) elif activation_checkpointing_target == { - 'mptblock': [1], - 'grouped_query_attention': 'first-1, last-1' + 'mptblock': [1], + 'grouped_query_attention': 'first-1, last-1', }: assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. - blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + blocks[0]._fsdp_wrapped_module.attn, + CheckpointWrapper, + ) assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. - blocks[1]._fsdp_wrapped_module, CheckpointWrapper) + blocks[1]._fsdp_wrapped_module, + CheckpointWrapper, + ) assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. - blocks[2]._fsdp_wrapped_module.attn, CheckpointWrapper) + blocks[2]._fsdp_wrapped_module.attn, + CheckpointWrapper, + ) else: raise ValueError( - f'Unknown activation_checkpointing_target: {activation_checkpointing_target}' + f'Unknown activation_checkpointing_target: {activation_checkpointing_target}', ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 402698cb27..243e45b671 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -19,9 +19,14 @@ from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, - PreTrainedTokenizer, PreTrainedTokenizerFast, - pipeline) +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + pipeline, +) from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.bloom.modeling_bloom import build_alibi_tensor @@ -29,8 +34,10 @@ from llmfoundry.layers_registry import norms from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers import build_alibi_bias -from llmfoundry.models.layers.attention import (check_alibi_support, - is_flash_v2_installed) +from llmfoundry.models.layers.attention import ( + check_alibi_support, + is_flash_v2_installed, +) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer @@ -38,7 +45,7 @@ def get_config( - conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml' + conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', ) -> DictConfig: os.environ['TOKENIZERS_PARALLELISM'] = 'false' print(conf_path) @@ -53,17 +60,22 @@ def _load_tokenizer_cfg(cfg: DictConfig) -> Dict: return config -def _get_objs(request: pytest.FixtureRequest, - conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): +def _get_objs( + request: pytest.FixtureRequest, + conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', +): warnings.filterwarnings( action='ignore', - message='Torchmetrics v0.9 introduced a new argument class property') + message='Torchmetrics v0.9 introduced a new argument class property', + ) test_cfg = get_config(conf_path=conf_path) # Read FSDP Config as a dict fsdp_config = test_cfg.get('fsdp_config', None) - fsdp_config = om.to_container(fsdp_config, - resolve=True) if fsdp_config else None + fsdp_config = om.to_container( + fsdp_config, + resolve=True, + ) if fsdp_config else None # Check if we are running on GPU is_gpu = False @@ -86,8 +98,10 @@ def _get_objs(request: pytest.FixtureRequest, test_cfg.device_train_microbatch_size = 2 tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) - tokenizer = build_tokenizer(test_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + test_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model = build_composer_model( name=test_cfg.model.name, @@ -97,18 +111,22 @@ def _get_objs(request: pytest.FixtureRequest, # Optimizer assert test_cfg.optimizer.name == 'decoupled_adamw' - optimizer = DecoupledAdamW(model.parameters(), - lr=test_cfg.optimizer.lr, - betas=test_cfg.optimizer.betas, - eps=test_cfg.optimizer.eps, - weight_decay=test_cfg.optimizer.weight_decay) + optimizer = DecoupledAdamW( + model.parameters(), + lr=test_cfg.optimizer.lr, + betas=test_cfg.optimizer.betas, + eps=test_cfg.optimizer.eps, + weight_decay=test_cfg.optimizer.weight_decay, + ) return test_cfg, model, optimizer -def gen_random_batch(batch_size: int, - test_cfg: Union[DictConfig, ListConfig], - inputs: Optional[List[str]] = None): +def gen_random_batch( + batch_size: int, + test_cfg: Union[DictConfig, ListConfig], + inputs: Optional[List[str]] = None, +): # inputs can be [], ['input_ids'], ['input_ids', 'inputs_embeds'], and ['inputs_embeds'] # default to only input ids if inputs == None: @@ -120,53 +138,74 @@ def gen_random_batch(batch_size: int, batch['input_ids'] = torch.randint( low=0, high=test_cfg.model.vocab_size, - size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device) + size=(batch_size, test_cfg.max_seq_len), + ).to(test_cfg.device) if inp == 'inputs_embeds': batch['inputs_embeds'] = torch.randn( - batch_size, test_cfg.max_seq_len, - test_cfg.model.d_model).to(test_cfg.device) - - batch['labels'] = torch.randint(low=0, - high=test_cfg.model.vocab_size, - size=(batch_size, test_cfg.max_seq_len)).to( - test_cfg.device) - batch['attention_mask'] = torch.ones(size=(batch_size, - test_cfg.max_seq_len), - dtype=torch.int64).to(test_cfg.device) + batch_size, + test_cfg.max_seq_len, + test_cfg.model.d_model, + ).to(test_cfg.device) + + batch['labels'] = torch.randint( + low=0, + high=test_cfg.model.vocab_size, + size=(batch_size, test_cfg.max_seq_len), + ).to(test_cfg.device) + batch['attention_mask'] = torch.ones( + size=(batch_size, test_cfg.max_seq_len), + dtype=torch.int64, + ).to(test_cfg.device) return batch -def gen_random_enc_dec_batch(batch_size: int, vocab_size: int, max_seq_len: int, - device: str): +def gen_random_enc_dec_batch( + batch_size: int, + vocab_size: int, + max_seq_len: int, + device: str, +): # generate input batch of random data, suitable for a T5 batch = {} - batch['input_ids'] = torch.randint(low=0, - high=vocab_size, - size=(batch_size, - max_seq_len)).to(device) - batch['labels'] = torch.randint(low=0, - high=vocab_size, - size=(batch_size, max_seq_len)).to(device) + batch['input_ids'] = torch.randint( + low=0, + high=vocab_size, + size=(batch_size, max_seq_len), + ).to(device) + batch['labels'] = torch.randint( + low=0, + high=vocab_size, + size=(batch_size, max_seq_len), + ).to(device) batch['decoder_input_ids'] = torch.zeros_like(batch['labels']) batch['decoder_input_ids'][:, 1:] = batch['labels'][:, :-1] - batch['attention_mask'] = torch.ones(size=(batch_size, max_seq_len), - dtype=torch.int64).to(device) + batch['attention_mask'] = torch.ones( + size=(batch_size, max_seq_len), + dtype=torch.int64, + ).to(device) batch['decoder_attention_mask'] = batch['attention_mask'].clone() return batch -@pytest.mark.parametrize('conf_path', [ - 'scripts/train/yamls/pretrain/testing.yaml', -]) -def test_full_forward_and_backward(request: pytest.FixtureRequest, - conf_path: str, - batch_size: int = 2): +@pytest.mark.parametrize( + 'conf_path', + [ + 'scripts/train/yamls/pretrain/testing.yaml', + ], +) +def test_full_forward_and_backward( + request: pytest.FixtureRequest, + conf_path: str, + batch_size: int = 2, +): test_cfg, model, optimizer = _get_objs(request=request, conf_path=conf_path) batch = gen_random_batch(batch_size, test_cfg) - assert batch['input_ids'].shape == torch.Size( - [batch_size, test_cfg.max_seq_len]) + assert batch['input_ids'].shape == torch.Size([ + batch_size, + test_cfg.max_seq_len, + ]) model.train() original_params = next(model.parameters()).clone().data outputs = model(batch) @@ -179,9 +218,13 @@ def test_full_forward_and_backward(request: pytest.FixtureRequest, def test_full_forward_and_backward_with_inputs_embeds( - request: pytest.FixtureRequest, batch_size: int = 2): + request: pytest.FixtureRequest, + batch_size: int = 2, +): test_cfg, model, optimizer = _get_objs( - request=request, conf_path='scripts/train/yamls/pretrain/testing.yaml') + request=request, + conf_path='scripts/train/yamls/pretrain/testing.yaml', + ) batch = gen_random_batch(batch_size, test_cfg, inputs=['inputs_embeds']) @@ -198,9 +241,13 @@ def test_full_forward_and_backward_with_inputs_embeds( @pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']]) def test_invalid_inputs_embeds_input_ids_combinations( - request: pytest.FixtureRequest, inputs: List[str]): + request: pytest.FixtureRequest, + inputs: List[str], +): test_cfg, model, _ = _get_objs( - request=request, conf_path='scripts/train/yamls/pretrain/testing.yaml') + request=request, + conf_path='scripts/train/yamls/pretrain/testing.yaml', + ) batch = gen_random_batch(2, test_cfg, inputs=inputs) @@ -209,22 +256,29 @@ def test_invalid_inputs_embeds_input_ids_combinations( _ = model(batch) -@pytest.mark.parametrize('conf_path', [ - 'scripts/train/yamls/pretrain/testing.yaml', - pytest.param('scripts/train/yamls/pretrain/testing-moe.yaml', - marks=pytest.mark.gpu), -]) -def test_attention_mechanism(request: pytest.FixtureRequest, - conf_path: str, - batch_size: int = 2): +@pytest.mark.parametrize( + 'conf_path', + [ + 'scripts/train/yamls/pretrain/testing.yaml', + pytest.param( + 'scripts/train/yamls/pretrain/testing-moe.yaml', + marks=pytest.mark.gpu, + ), + ], +) +def test_attention_mechanism( + request: pytest.FixtureRequest, + conf_path: str, + batch_size: int = 2, +): test_cfg, model, _ = _get_objs(request=request, conf_path=conf_path) batch = gen_random_batch(batch_size, test_cfg) model.eval() # run a partial forward where we explicitly inspect the attention_mask from the causal_attn block - input_ids, attention_mask = batch['input_ids'], batch[ - 'attention_mask'].bool() + input_ids, attention_mask = batch['input_ids'], batch['attention_mask' + ].bool() _, S = input_ids.size() assert ( @@ -244,7 +298,8 @@ def test_attention_mechanism(request: pytest.FixtureRequest, expected_zerod_weights = nn.Transformer.generate_square_subsequent_mask(test_cfg.max_seq_len, device=test_cfg.device)\ .reshape(1, test_cfg.max_seq_len, test_cfg.max_seq_len) expected_zerod_weights = torch.isneginf( - torch.cat(batch_size * [expected_zerod_weights])) + torch.cat(batch_size * [expected_zerod_weights]), + ) torch_key_padding = torch.cat( # type: ignore test_cfg.max_seq_len * [(~attention_mask).reshape(batch_size, 1, test_cfg.max_seq_len)], @@ -252,7 +307,10 @@ def test_attention_mechanism(request: pytest.FixtureRequest, expected_zerod_weights |= torch_key_padding attn_bias, attention_mask = model.model.transformer._attn_bias( - device=x.device, dtype=x.dtype, attention_mask=attention_mask) + device=x.device, + dtype=x.dtype, + attention_mask=attention_mask, + ) for block in model.model.transformer.blocks: a = block.norm_1(x) @@ -262,12 +320,14 @@ def test_attention_mechanism(request: pytest.FixtureRequest, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=model.model.transformer.is_causal, - needs_weights=True) + needs_weights=True, + ) zerod_weights = (attention_weights == 0) assert torch.equal( expected_zerod_weights.expand(*zerod_weights.shape), - zerod_weights) + zerod_weights, + ) x = x + block.resid_attn_dropout(b) m = block.norm_2(x) n = block.ffn(m) @@ -277,7 +337,8 @@ def test_attention_mechanism(request: pytest.FixtureRequest, def test_full_forward_and_backward_gpt2_small(batch_size: int = 2): warnings.filterwarnings( action='ignore', - message='Torchmetrics v0.9 introduced a new argument class property') + message='Torchmetrics v0.9 introduced a new argument class property', + ) conf_path = 'scripts/train/yamls/pretrain/gpt2-small.yaml' with open(conf_path) as f: neo_cfg = om.load(f) @@ -288,8 +349,10 @@ def test_full_forward_and_backward_gpt2_small(batch_size: int = 2): neo_cfg.model.name = 'hf_causal_lm' tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(neo_cfg.tokenizer) - tokenizer = build_tokenizer(neo_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + neo_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model = build_composer_model( name=neo_cfg.model.name, @@ -297,22 +360,28 @@ def test_full_forward_and_backward_gpt2_small(batch_size: int = 2): tokenizer=tokenizer, ).to(device) - assert isinstance(model.tokenizer, - (PreTrainedTokenizer, PreTrainedTokenizerFast)) + assert isinstance( + model.tokenizer, + (PreTrainedTokenizer, PreTrainedTokenizerFast), + ) assert neo_cfg.optimizer.name == 'decoupled_adamw' - optimizer = DecoupledAdamW(model.parameters(), - lr=neo_cfg.optimizer.lr, - betas=neo_cfg.optimizer.betas, - eps=neo_cfg.optimizer.eps, - weight_decay=neo_cfg.optimizer.weight_decay) + optimizer = DecoupledAdamW( + model.parameters(), + lr=neo_cfg.optimizer.lr, + betas=neo_cfg.optimizer.betas, + eps=neo_cfg.optimizer.eps, + weight_decay=neo_cfg.optimizer.weight_decay, + ) # set vocab size using model num_embeddings neo_cfg.model.vocab_size = model.model.transformer.wte.num_embeddings batch = gen_random_batch(batch_size, neo_cfg) - assert batch['input_ids'].shape == torch.Size( - [batch_size, neo_cfg.max_seq_len]) + assert batch['input_ids'].shape == torch.Size([ + batch_size, + neo_cfg.max_seq_len, + ]) model.train() original_params = next(model.parameters()).clone().data outputs = model(batch) @@ -327,7 +396,8 @@ def test_full_forward_and_backward_gpt2_small(batch_size: int = 2): def test_full_forward_and_backward_t5_small(batch_size: int = 2): warnings.filterwarnings( action='ignore', - message='Torchmetrics v0.9 introduced a new argument class property') + message='Torchmetrics v0.9 introduced a new argument class property', + ) conf_path = 'scripts/train/yamls/finetune/t5-small_dolly_sft.yaml' with open(conf_path) as f: t5_cfg = om.load(f) @@ -337,8 +407,10 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): t5_cfg.max_seq_len = 16 tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(t5_cfg.tokenizer) - tokenizer = build_tokenizer(t5_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + t5_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model = build_composer_model( name=t5_cfg.model.name, @@ -346,21 +418,31 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): tokenizer=tokenizer, ).to(device) - assert isinstance(model.tokenizer, - (PreTrainedTokenizer, PreTrainedTokenizerFast)) + assert isinstance( + model.tokenizer, + (PreTrainedTokenizer, PreTrainedTokenizerFast), + ) - optimizer = DecoupledAdamW(model.parameters(), - lr=t5_cfg.optimizer.lr, - betas=t5_cfg.optimizer.betas, - eps=t5_cfg.optimizer.eps, - weight_decay=t5_cfg.optimizer.weight_decay) + optimizer = DecoupledAdamW( + model.parameters(), + lr=t5_cfg.optimizer.lr, + betas=t5_cfg.optimizer.betas, + eps=t5_cfg.optimizer.eps, + weight_decay=t5_cfg.optimizer.weight_decay, + ) # set vocab size using model num_embeddings - batch = gen_random_enc_dec_batch(batch_size, model.model.config.vocab_size, - t5_cfg.max_seq_len, device) + batch = gen_random_enc_dec_batch( + batch_size, + model.model.config.vocab_size, + t5_cfg.max_seq_len, + device, + ) - assert batch['input_ids'].shape == torch.Size( - [batch_size, t5_cfg.max_seq_len]) + assert batch['input_ids'].shape == torch.Size([ + batch_size, + t5_cfg.max_seq_len, + ]) model.train() original_params = next(model.parameters()).clone().data outputs = model(batch) @@ -377,27 +459,37 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): 'attn_impl,precision', [('torch', torch.float16), ('torch', torch.bfloat16), pytest.param('flash', torch.float16, marks=pytest.mark.gpu), - pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)]) + pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)], +) @pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptglu']) -@pytest.mark.parametrize('ffn_act_fn', [ - None, - { - 'name': 'gelu', - 'approximate': 'tanh', - }, - { - 'name': 'silu', - }, - { - 'name': 'relu', - 'inplace': True, - }, - pytest.param({'name': 'relu5'}, - marks=pytest.mark.xfail(reason='invalid choice.', - strict=True)), -]) -def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str, - ffn_act_fn: dict): +@pytest.mark.parametrize( + 'ffn_act_fn', + [ + None, + { + 'name': 'gelu', + 'approximate': 'tanh', + }, + { + 'name': 'silu', + }, + { + 'name': 'relu', + 'inplace': True, + }, + pytest.param({'name': 'relu5'}, + marks=pytest.mark.xfail( + reason='invalid choice.', + strict=True, + )), + ], +) +def test_determinism( + attn_impl: str, + precision: torch.dtype, + ffn_type: str, + ffn_act_fn: dict, +): conf_path = 'scripts/train/yamls/pretrain/testing.yaml' with open(conf_path) as f: test_cfg = om.load(f) @@ -414,8 +506,10 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str, test_cfg.device = 'cuda:0' tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) - tokenizer = build_tokenizer(test_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + test_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model_1 = build_composer_model( name=test_cfg.model.name, @@ -424,24 +518,31 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str, ) model_2 = copy.deepcopy(model_1) - optimizer_1 = DecoupledAdamW(model_1.parameters(), - lr=test_cfg.optimizer.lr, - betas=test_cfg.optimizer.betas, - eps=test_cfg.optimizer.eps, - weight_decay=test_cfg.optimizer.weight_decay) - optimizer_2 = DecoupledAdamW(model_2.parameters(), - lr=test_cfg.optimizer.lr, - betas=test_cfg.optimizer.betas, - eps=test_cfg.optimizer.eps, - weight_decay=test_cfg.optimizer.weight_decay) + optimizer_1 = DecoupledAdamW( + model_1.parameters(), + lr=test_cfg.optimizer.lr, + betas=test_cfg.optimizer.betas, + eps=test_cfg.optimizer.eps, + weight_decay=test_cfg.optimizer.weight_decay, + ) + optimizer_2 = DecoupledAdamW( + model_2.parameters(), + lr=test_cfg.optimizer.lr, + betas=test_cfg.optimizer.betas, + eps=test_cfg.optimizer.eps, + weight_decay=test_cfg.optimizer.weight_decay, + ) for i in range(5): with torch.cuda.amp.autocast(True, precision): batch = gen_random_batch(2, test_cfg) output_1 = model_1(batch) output_2 = model_2(batch) - assert output_1.logits.allclose(output_2.logits, rtol=0.0, - atol=0.0), f'differed at step {i}' + assert output_1.logits.allclose( + output_2.logits, + rtol=0.0, + atol=0.0, + ), f'differed at step {i}' loss_1 = model_1.loss(output_1, batch) loss_2 = model_2.loss(output_2, batch) assert isinstance(loss_1, torch.Tensor) @@ -484,8 +585,10 @@ def test_loss_fn(): } tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) - tokenizer = build_tokenizer(test_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + test_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model_1 = build_composer_model( name=test_cfg.model.name, @@ -500,30 +603,40 @@ def test_loss_fn(): assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss) model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, reduction='none') - optimizer_1 = DecoupledAdamW(model_1.parameters(), - lr=test_cfg.optimizer.lr, - betas=test_cfg.optimizer.betas, - eps=test_cfg.optimizer.eps, - weight_decay=test_cfg.optimizer.weight_decay) - optimizer_2 = DecoupledAdamW(model_2.parameters(), - lr=test_cfg.optimizer.lr, - betas=test_cfg.optimizer.betas, - eps=test_cfg.optimizer.eps, - weight_decay=test_cfg.optimizer.weight_decay) + optimizer_1 = DecoupledAdamW( + model_1.parameters(), + lr=test_cfg.optimizer.lr, + betas=test_cfg.optimizer.betas, + eps=test_cfg.optimizer.eps, + weight_decay=test_cfg.optimizer.weight_decay, + ) + optimizer_2 = DecoupledAdamW( + model_2.parameters(), + lr=test_cfg.optimizer.lr, + betas=test_cfg.optimizer.betas, + eps=test_cfg.optimizer.eps, + weight_decay=test_cfg.optimizer.weight_decay, + ) for i in range(15): batch = gen_random_batch(2, test_cfg) output_1 = model_1(batch) output_2 = model_2(batch) - assert output_1.logits.allclose(output_2.logits, rtol=1e-4, - atol=1e-4), f'differed at step {i}' + assert output_1.logits.allclose( + output_2.logits, + rtol=1e-4, + atol=1e-4, + ), f'differed at step {i}' loss_1 = model_1.loss(output_1, batch) loss_2 = model_2.loss(output_2, batch) assert isinstance(loss_1, torch.Tensor) assert isinstance(loss_2, torch.Tensor) - assert loss_1.allclose(loss_2, rtol=1e-3, - atol=1e-3), f'differed at step {i}' + assert loss_1.allclose( + loss_2, + rtol=1e-3, + atol=1e-3, + ), f'differed at step {i}' loss_1.backward() loss_2.backward() optimizer_1.step() @@ -531,13 +644,18 @@ def test_loss_fn(): for p1, p2 in zip(model_1.parameters(), model_2.parameters()): assert p1.data.shape == p2.data.shape - assert p1.data.allclose(p2.data, rtol=1e-5, - atol=1e-4), f'differed at step {i}' + assert p1.data.allclose( + p2.data, + rtol=1e-5, + atol=1e-4, + ), f'differed at step {i}' @pytest.mark.gpu -@pytest.mark.parametrize('loss_fn_config', - ['torch_crossentropy', 'fused_crossentropy']) +@pytest.mark.parametrize( + 'loss_fn_config', + ['torch_crossentropy', 'fused_crossentropy'], +) def test_loss_reduction(loss_fn_config: str): """Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function. @@ -570,8 +688,10 @@ def test_loss_reduction(loss_fn_config: str): } tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) - tokenizer = build_tokenizer(test_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + test_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model_1 = build_composer_model( name=test_cfg.model.name, @@ -586,30 +706,41 @@ def test_loss_reduction(loss_fn_config: str): # Reduce the loss in FusedCrossEntropyLoss if loss_fn_config == 'fused_crossentropy': assert isinstance(model_1.loss_fn, FusedCrossEntropyLoss) - model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, - reduction='mean') + model_2.loss_fn = FusedCrossEntropyLoss( + ignore_index=-100, + reduction='mean', + ) else: assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss) - model_2.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, - reduction='mean') - - optimizer_1 = DecoupledAdamW(model_1.parameters(), - lr=test_cfg.optimizer.lr, - betas=test_cfg.optimizer.betas, - eps=test_cfg.optimizer.eps, - weight_decay=test_cfg.optimizer.weight_decay) - optimizer_2 = DecoupledAdamW(model_2.parameters(), - lr=test_cfg.optimizer.lr, - betas=test_cfg.optimizer.betas, - eps=test_cfg.optimizer.eps, - weight_decay=test_cfg.optimizer.weight_decay) + model_2.loss_fn = torch.nn.CrossEntropyLoss( + ignore_index=-100, + reduction='mean', + ) + + optimizer_1 = DecoupledAdamW( + model_1.parameters(), + lr=test_cfg.optimizer.lr, + betas=test_cfg.optimizer.betas, + eps=test_cfg.optimizer.eps, + weight_decay=test_cfg.optimizer.weight_decay, + ) + optimizer_2 = DecoupledAdamW( + model_2.parameters(), + lr=test_cfg.optimizer.lr, + betas=test_cfg.optimizer.betas, + eps=test_cfg.optimizer.eps, + weight_decay=test_cfg.optimizer.weight_decay, + ) for i in range(3): batch = gen_random_batch(2, test_cfg) output_1 = model_1(batch) output_2 = model_2(batch) - assert output_1.logits.allclose(output_2.logits, rtol=1e-4, - atol=1e-4), f'differed at step {i}' + assert output_1.logits.allclose( + output_2.logits, + rtol=1e-4, + atol=1e-4, + ), f'differed at step {i}' loss_1 = model_1.loss(output_1, batch) @@ -617,12 +748,16 @@ def test_loss_reduction(loss_fn_config: str): targets = model_2.get_targets(batch) loss_2 = model_2.loss_fn( output_2.logits.view(-1, output_2.logits.size(-1)), - targets.view(-1)) + targets.view(-1), + ) assert isinstance(loss_1, torch.Tensor) assert isinstance(loss_2, torch.Tensor) - assert loss_1.allclose(loss_2, rtol=1e-3, - atol=1e-3), f'differed at step {i}' + assert loss_1.allclose( + loss_2, + rtol=1e-3, + atol=1e-3, + ), f'differed at step {i}' loss_1.backward() loss_2.backward() optimizer_1.step() @@ -630,17 +765,23 @@ def test_loss_reduction(loss_fn_config: str): for p1, p2 in zip(model_1.parameters(), model_2.parameters()): assert p1.data.shape == p2.data.shape - assert p1.data.allclose(p2.data, rtol=1e-5, - atol=1e-4), f'differed at step {i}' + assert p1.data.allclose( + p2.data, + rtol=1e-5, + atol=1e-4, + ), f'differed at step {i}' -@pytest.mark.parametrize('peft_config', [ - None, - { - 'peft_type': 'LORA', - 'task_type': 'CAUSAL_LM' - }, -]) +@pytest.mark.parametrize( + 'peft_config', + [ + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + }, + ], +) def test_opt_wrapping(peft_config: Optional[dict[str, str]]): if peft_config is not None: _ = pytest.importorskip('peft') @@ -649,11 +790,11 @@ def test_opt_wrapping(peft_config: Optional[dict[str, str]]): 'model': { 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'facebook/opt-125m', - 'pretrained': False + 'pretrained': False, }, 'tokenizer': { - 'name': 'facebook/opt-125m' - } + 'name': 'facebook/opt-125m', + }, } if peft_config is not None: conf['model']['peft_config'] = peft_config @@ -661,8 +802,10 @@ def test_opt_wrapping(peft_config: Optional[dict[str, str]]): config = DictConfig(conf) tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer) - tokenizer = build_tokenizer(config.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + config.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model = ComposerHFCausalLM(config.model, tokenizer) @@ -685,15 +828,17 @@ def test_lora_id(): 'pretrained_lora_id_or_path': 'ybelkada/opt-350m-lora', }, 'tokenizer': { - 'name': 'facebook/opt-350m' - } + 'name': 'facebook/opt-350m', + }, } config = DictConfig(conf) tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer) - tokenizer = build_tokenizer(config.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + config.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) model = ComposerHFCausalLM(config.model, tokenizer) @@ -703,39 +848,55 @@ def test_lora_id(): @pytest.mark.parametrize('norm_type', norms.get_all()) @pytest.mark.parametrize('no_bias', [False, True]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -@pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [ - (2, None), - pytest.param(1.231, - None, - marks=pytest.mark.xfail( - reason='d_model * expansion_ratio must be an integer.', - strict=True)), - (2, 128), - (2, 256), -]) -@pytest.mark.parametrize('ffn_act_fn', [ - None, - { - 'name': 'gelu', - 'approximate': 'tanh', - }, - { - 'name': 'silu', - }, - { - 'name': 'relu', - 'inplace': True, - }, - pytest.param({'name': 'relu5'}, - marks=pytest.mark.xfail(reason='invalid choice.', - strict=True)), -]) -def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool, - expansion_ratio: Union[int, float], ffn_hidden_size: int, - ffn_act_fn: dict): +@pytest.mark.parametrize( + 'expansion_ratio,ffn_hidden_size', + [ + (2, None), + pytest.param( + 1.231, + None, + marks=pytest.mark.xfail( + reason='d_model * expansion_ratio must be an integer.', + strict=True, + ), + ), + (2, 128), + (2, 256), + ], +) +@pytest.mark.parametrize( + 'ffn_act_fn', + [ + None, + { + 'name': 'gelu', + 'approximate': 'tanh', + }, + { + 'name': 'silu', + }, + { + 'name': 'relu', + 'inplace': True, + }, + pytest.param({'name': 'relu5'}, + marks=pytest.mark.xfail( + reason='invalid choice.', + strict=True, + )), + ], +) +def test_mpt_creation( + norm_type: str, + no_bias: bool, + tie_word_embeddings: bool, + expansion_ratio: Union[int, float], + ffn_hidden_size: int, + ffn_act_fn: dict, +): if norm_type == 'triton_rmsnorm' and not is_flash_v2_installed(): pytest.skip( - f'norm_type=triton_rmsnorm requires flash Attention to be installed' + f'norm_type=triton_rmsnorm requires flash Attention to be installed', ) # Test that the config constructs the model as expected. @@ -772,13 +933,17 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool, assert mpt.config.ffn_config['ffn_hidden_size'] == ffn_hidden_size assert mpt.config.max_seq_len == 2048 - assert mpt.transformer.wte.weight.shape == torch.Size( - [hf_config.vocab_size, hf_config.d_model]) + assert mpt.transformer.wte.weight.shape == torch.Size([ + hf_config.vocab_size, + hf_config.d_model, + ]) if not tie_word_embeddings: assert mpt.lm_head is not None assert mpt.lm_head.weight.shape == mpt.transformer.wte.weight.shape - assert mpt.transformer.wpe.weight.shape == torch.Size( - [hf_config.max_seq_len, hf_config.d_model]) + assert mpt.transformer.wpe.weight.shape == torch.Size([ + hf_config.max_seq_len, + hf_config.d_model, + ]) assert mpt.transformer.emb_drop.p == 0.1 assert len(mpt.transformer.blocks) == 2 @@ -791,40 +956,47 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool, assert block.norm_2 is not None assert block.norm_2.weight.shape == torch.Size([d_model]) assert isinstance(block.ffn.up_proj, nn.Linear) - assert block.ffn.up_proj.weight.shape == torch.Size( - [ffn_hidden_size, hf_config.d_model]) + assert block.ffn.up_proj.weight.shape == torch.Size([ + ffn_hidden_size, + hf_config.d_model, + ]) assert isinstance(block.ffn.down_proj, nn.Linear) - assert block.ffn.down_proj.weight.shape == torch.Size( - [hf_config.d_model, ffn_hidden_size]) + assert block.ffn.down_proj.weight.shape == torch.Size([ + hf_config.d_model, + ffn_hidden_size, + ]) assert block.resid_attn_dropout.p == 0.2 assert block.resid_ffn_dropout.p == 0.2 @pytest.mark.gpu @pytest.mark.parametrize('attention_impl', ['flash', 'torch']) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): # Testing the output of concatenated sequence with sequence id masking vs individual sequences. alibi = pos_emb_config['alibi'] @@ -832,15 +1004,17 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') rope = pos_emb_config['rope'] - if rope and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + if rope and pos_emb_config['rope_impl' + ] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) if attention_impl == 'flash' and ( - not is_flash_v2_installed(v2_version='v2.1.2')): + not is_flash_v2_installed(v2_version='v2.1.2') + ): pytest.skip( - 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', ) composer_device = get_device(None) @@ -868,13 +1042,15 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): mpt.eval() mpt = composer_device.module_to_device(mpt) - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): # padding on the right side of the input concatenated_seq_ids = torch.tensor([[11274, 16390, 11, 4332, 323, 423], [2342, 12, 111, 123, 50256, 342]]) concatenated_seq_ids = composer_device.tensor_to_device( - concatenated_seq_ids) + concatenated_seq_ids, + ) sequence_id = torch.tensor([[0, 0, 0, 1, 2, 2], [0, 0, 0, 1, 2, 2]]) sequence_id = composer_device.tensor_to_device(sequence_id) @@ -888,72 +1064,90 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): third_seq_ids = torch.tensor([[323, 423], [50256, 342]]) third_seq_ids = composer_device.tensor_to_device(third_seq_ids) - concatenated_seq_output = mpt(concatenated_seq_ids, - sequence_id=sequence_id).logits + concatenated_seq_output = mpt( + concatenated_seq_ids, + sequence_id=sequence_id, + ).logits first_seq_output = mpt(first_seq_ids).logits second_seq_output = mpt(second_seq_ids).logits third_seq_output = mpt(third_seq_ids).logits - assert torch.allclose(concatenated_seq_output[:, :3], - first_seq_output, - atol=2e-6 if attention_impl == 'torch' else 1e-8) - assert torch.allclose(concatenated_seq_output[:, 3:4], - second_seq_output, - atol=2e-6 if attention_impl == 'torch' else 1e-8) + assert torch.allclose( + concatenated_seq_output[:, :3], + first_seq_output, + atol=2e-6 if attention_impl == 'torch' else 1e-8, + ) + assert torch.allclose( + concatenated_seq_output[:, 3:4], + second_seq_output, + atol=2e-6 if attention_impl == 'torch' else 1e-8, + ) atol = 1e-8 if attention_impl == 'torch': atol = 2e-6 elif pos_emb_config['rope']: atol = 2e-2 - assert torch.allclose(concatenated_seq_output[:, 4:6], - third_seq_output, - atol=atol) - - -@pytest.mark.parametrize('attention_impl', [ - 'torch', - pytest.param('flash', marks=pytest.mark.gpu), - pytest.param('torch', marks=pytest.mark.gpu) -]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) + assert torch.allclose( + concatenated_seq_output[:, 4:6], + third_seq_output, + atol=atol, + ) + + +@pytest.mark.parametrize( + 'attention_impl', + [ + 'torch', + pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('torch', marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_forward_with_padding( + attention_impl: str, + pos_emb_config: dict, + tie_word_embeddings: bool, +): # Test that different placement of padding does not affect the output. alibi = pos_emb_config['alibi'] if alibi and not check_alibi_support(attention_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') rope = pos_emb_config['rope'] - if rope and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + if rope and pos_emb_config['rope_impl' + ] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) composer_device = get_device(None) @@ -980,65 +1174,92 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, mpt.eval() mpt = composer_device.module_to_device(mpt) - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): # padding on the right side of the input - right_padding_input_ids = torch.tensor( - [[11274, 16390, 11, 50256, 50256, 50256], - [11274, 16390, 11, 50256, 50256, 50256]]) + right_padding_input_ids = torch.tensor([[ + 11274, + 16390, + 11, + 50256, + 50256, + 50256, + ], [11274, 16390, 11, 50256, 50256, 50256]]) right_padding_input_ids = composer_device.tensor_to_device( - right_padding_input_ids) + right_padding_input_ids, + ) right_padding_attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, - 0]]).bool() + [1, 1, 1, 0, 0, 0]]).bool() right_padding_attention_mask = composer_device.tensor_to_device( - right_padding_attention_mask) + right_padding_attention_mask, + ) # padding in the middle of the input - middle_padding_input_ids = torch.tensor( - [[11274, 16390, 50256, 50256, 50256, 11], - [11274, 16390, 50256, 50256, 50256, 11]]) + middle_padding_input_ids = torch.tensor([[ + 11274, + 16390, + 50256, + 50256, + 50256, + 11, + ], [11274, 16390, 50256, 50256, 50256, 11]]) middle_padding_input_ids = composer_device.tensor_to_device( - middle_padding_input_ids) + middle_padding_input_ids, + ) middle_padding_attention_mask = torch.tensor([[1, 1, 0, 0, 0, 1], [1, 1, 0, 0, 0, 1]]).bool() middle_padding_attention_mask = composer_device.tensor_to_device( - middle_padding_attention_mask) + middle_padding_attention_mask, + ) # padding on the left side of the input - left_padding_input_ids = torch.tensor( - [[50256, 50256, 50256, 11274, 16390, 11], - [50256, 50256, 50256, 11274, 16390, 11]]) + left_padding_input_ids = torch.tensor([[ + 50256, + 50256, + 50256, + 11274, + 16390, + 11, + ], [50256, 50256, 50256, 11274, 16390, 11]]) left_padding_input_ids = composer_device.tensor_to_device( - left_padding_input_ids) + left_padding_input_ids, + ) left_padding_attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]]).bool() left_padding_attention_mask = composer_device.tensor_to_device( - left_padding_attention_mask) + left_padding_attention_mask, + ) # a single batch with padding in different places batched_input_ids = torch.tensor([ [11274, 16390, 11, 50256, 50256, 50256], # right padding - [11274, 16390, 50256, 50256, 50256, 11] + [11274, 16390, 50256, 50256, 50256, 11], ]) # middle padding batched_input_ids = composer_device.tensor_to_device(batched_input_ids) batched_attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 0, 0, 0, 1]]).bool() batched_attention_mask = composer_device.tensor_to_device( - batched_attention_mask) + batched_attention_mask, + ) right_padding_output = mpt( right_padding_input_ids, - attention_mask=right_padding_attention_mask).logits + attention_mask=right_padding_attention_mask, + ).logits middle_padding_output = mpt( middle_padding_input_ids, - attention_mask=middle_padding_attention_mask).logits + attention_mask=middle_padding_attention_mask, + ).logits left_padding_output = mpt( left_padding_input_ids, - attention_mask=left_padding_attention_mask).logits - batched_output = mpt(batched_input_ids, - attention_mask=batched_attention_mask).logits + attention_mask=left_padding_attention_mask, + ).logits + batched_output = mpt( + batched_input_ids, + attention_mask=batched_attention_mask, + ).logits # check that right padding and left padding produce the same output right_pad_v_left_pad_rtol = 1e-5 @@ -1047,10 +1268,12 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. right_pad_v_left_pad_rtol = 2e-2 right_pad_v_left_pad_atol = 2e-2 - assert torch.allclose(right_padding_output[0, :3], - left_padding_output[0, 3:], - rtol=right_pad_v_left_pad_rtol, - atol=right_pad_v_left_pad_atol) + assert torch.allclose( + right_padding_output[0, :3], + left_padding_output[0, 3:], + rtol=right_pad_v_left_pad_rtol, + atol=right_pad_v_left_pad_atol, + ) if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that right padding and middle padding produce the same output @@ -1059,12 +1282,15 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, assert torch.allclose( right_padding_output[0, :3], middle_padding_output[0, [0, 1, 5]], - atol=1e-6 if attention_impl == 'torch' else 1e-8) + atol=1e-6 if attention_impl == 'torch' else 1e-8, + ) # check that right padding and right padding in a batch produce the same output - assert torch.allclose(right_padding_output[0, :3], - batched_output[0, :3], - atol=1e-6 if attention_impl == 'torch' else 1e-8) + assert torch.allclose( + right_padding_output[0, :3], + batched_output[0, :3], + atol=1e-6 if attention_impl == 'torch' else 1e-8, + ) if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that middle padding and middle padding in a batch produce the same output @@ -1073,7 +1299,8 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, assert torch.allclose( middle_padding_output[0], batched_output[1, :], - atol=1e-6 if attention_impl == 'torch' else 1e-8) + atol=1e-6 if attention_impl == 'torch' else 1e-8, + ) try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -1088,77 +1315,96 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict, right_padding_output_pad_flipped = mpt( right_padding_input_ids, - attention_mask=right_padding_attention_mask).logits + attention_mask=right_padding_attention_mask, + ).logits middle_padding_output_pad_flipped = mpt( middle_padding_input_ids, - attention_mask=middle_padding_attention_mask).logits + attention_mask=middle_padding_attention_mask, + ).logits left_padding_output_pad_flipped = mpt( left_padding_input_ids, - attention_mask=left_padding_attention_mask).logits + attention_mask=left_padding_attention_mask, + ).logits pad_vs_unpad_rtol = 1e-5 pad_vs_unpad_atol = 1e-6 - assert torch.allclose(right_padding_output[0, :3], - right_padding_output_pad_flipped[0, :3], - rtol=pad_vs_unpad_rtol, - atol=pad_vs_unpad_atol) - - assert torch.allclose(middle_padding_output[0, [0, 1, 5]], - middle_padding_output_pad_flipped[0, - [0, 1, 5]], - rtol=pad_vs_unpad_rtol, - atol=pad_vs_unpad_atol) - - assert torch.allclose(left_padding_output[0, 3:], - left_padding_output_pad_flipped[0, 3:], - rtol=pad_vs_unpad_rtol, - atol=pad_vs_unpad_atol) - - -@pytest.mark.parametrize('attention_impl,precision', [ - ('torch', 'fp32'), - pytest.param('flash', 'amp_bf16', marks=pytest.mark.gpu), - pytest.param('torch', 'amp_bf16', marks=pytest.mark.gpu), - pytest.param('torch', 'fp32', marks=pytest.mark.gpu), -]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) + assert torch.allclose( + right_padding_output[0, :3], + right_padding_output_pad_flipped[0, :3], + rtol=pad_vs_unpad_rtol, + atol=pad_vs_unpad_atol, + ) + + assert torch.allclose( + middle_padding_output[0, [0, 1, 5]], + middle_padding_output_pad_flipped[0, [0, 1, 5]], + rtol=pad_vs_unpad_rtol, + atol=pad_vs_unpad_atol, + ) + + assert torch.allclose( + left_padding_output[0, 3:], + left_padding_output_pad_flipped[0, 3:], + rtol=pad_vs_unpad_rtol, + atol=pad_vs_unpad_atol, + ) + + +@pytest.mark.parametrize( + 'attention_impl,precision', + [ + ('torch', 'fp32'), + pytest.param('flash', 'amp_bf16', marks=pytest.mark.gpu), + pytest.param('torch', 'amp_bf16', marks=pytest.mark.gpu), + pytest.param('torch', 'fp32', marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_generate( + attention_impl: str, + precision: str, + pos_emb_config: dict, + tie_word_embeddings: bool, +): # Test that generate works, and produces the same output with or without # padding in the input. if pos_emb_config['alibi'] and not check_alibi_support(attention_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) if attention_impl == 'torch' and precision == 'amp_bf16' and tie_word_embeddings == False: pytest.skip(f'This test configuration has precision / sampling issues.') @@ -1184,24 +1430,33 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, mpt.eval() # padding on the left of the input - left_padding_input_ids = torch.tensor( - [[50256, 50256, 50256, 11274, 16390, 11], - [50256, 50256, 50256, 11274, 16390, 11]]) + left_padding_input_ids = torch.tensor([[ + 50256, + 50256, + 50256, + 11274, + 16390, + 11, + ], [50256, 50256, 50256, 11274, 16390, 11]]) left_padding_input_ids = composer_device.tensor_to_device( - left_padding_input_ids) + left_padding_input_ids, + ) left_padding_attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]]) left_padding_attention_mask = composer_device.tensor_to_device( - left_padding_attention_mask) + left_padding_attention_mask, + ) # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11], [11274, 16390, 11]]) no_padding_input_ids = composer_device.tensor_to_device( - no_padding_input_ids) + no_padding_input_ids, + ) no_padding_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]) no_padding_attention_mask = composer_device.tensor_to_device( - no_padding_attention_mask) + no_padding_attention_mask, + ) # inputs_embeds inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128)) @@ -1213,61 +1468,78 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, batched_attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 1, 1]]).bool() batched_attention_mask = composer_device.tensor_to_device( - batched_attention_mask) + batched_attention_mask, + ) with get_precision_context(precision): # check that a batch with different amounts of padding doesn't crash # and produces the right output shape - batched_generation = mpt.generate(input_ids=batched_input_ids, - attention_mask=batched_attention_mask, - max_new_tokens=5, - use_cache=False) + batched_generation = mpt.generate( + input_ids=batched_input_ids, + attention_mask=batched_attention_mask, + max_new_tokens=5, + use_cache=False, + ) assert batched_generation.shape == (2, 6 + 5) generation_with_left_padding = mpt.generate( input_ids=left_padding_input_ids, attention_mask=left_padding_attention_mask, max_new_tokens=5, - use_cache=False) + use_cache=False, + ) assert generation_with_left_padding.shape == (2, 6 + 5) generation_with_no_padding = mpt.generate( input_ids=no_padding_input_ids, attention_mask=no_padding_attention_mask, max_new_tokens=5, - use_cache=False) + use_cache=False, + ) assert generation_with_no_padding.shape == (2, 3 + 5) # check that left padding and no padding produce the same output assert generation_with_no_padding[:, 3:].equal( - generation_with_left_padding[:, 6:]) + generation_with_left_padding[:, 6:], + ) # check that both/neither ids and embeds do not error # note that we need to set the BOS token ID for generating from neither - _ = mpt.generate(input_ids=no_padding_input_ids, - inputs_embeds=inputs_embeds, - attention_mask=no_padding_attention_mask, - max_new_tokens=5, - use_cache=False) - _ = mpt.generate(input_ids=no_padding_input_ids, - inputs_embeds=inputs_embeds, - attention_mask=no_padding_attention_mask, - max_new_tokens=5, - use_cache=True) - _ = mpt.generate(input_ids=None, - max_new_tokens=5, - use_cache=False, - bos_token_id=50256) - _ = mpt.generate(input_ids=None, - max_new_tokens=5, - use_cache=True, - bos_token_id=50256) + _ = mpt.generate( + input_ids=no_padding_input_ids, + inputs_embeds=inputs_embeds, + attention_mask=no_padding_attention_mask, + max_new_tokens=5, + use_cache=False, + ) + _ = mpt.generate( + input_ids=no_padding_input_ids, + inputs_embeds=inputs_embeds, + attention_mask=no_padding_attention_mask, + max_new_tokens=5, + use_cache=True, + ) + _ = mpt.generate( + input_ids=None, + max_new_tokens=5, + use_cache=False, + bos_token_id=50256, + ) + _ = mpt.generate( + input_ids=None, + max_new_tokens=5, + use_cache=True, + bos_token_id=50256, + ) @pytest.mark.gpu @pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, - tie_word_embeddings: bool): +def test_generate_with_device_map( + tmp_path: pathlib.Path, + world_size: int, + tie_word_embeddings: bool, +): if not torch.cuda.device_count() >= world_size: pytest.skip(f'This test requires {world_size} GPUs.') @@ -1321,8 +1593,10 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, ) -def check_hf_model_equivalence(model1: PreTrainedModel, - model2: PreTrainedModel): +def check_hf_model_equivalence( + model1: PreTrainedModel, + model2: PreTrainedModel, +): # Checks that two huggingface models are equivalent (config and # parameters) expected_model_config_dict = model1.config.to_dict() @@ -1333,11 +1607,12 @@ def check_hf_model_equivalence(model1: PreTrainedModel, del new_model_config_dict['_name_or_path'] assert expected_model_config_dict == new_model_config_dict - assert sum(p.numel() for p in model1.parameters()) == sum( - p.numel() for p in model2.parameters()) + assert sum(p.numel() for p in model1.parameters() + ) == sum(p.numel() for p in model2.parameters()) assert all( type(module1) == type(module2) - for module1, module2 in zip(model1.modules(), model2.modules())) + for module1, module2 in zip(model1.modules(), model2.modules()) + ) for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p1, p2) @@ -1367,44 +1642,51 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): check_hf_model_equivalence(mpt, mpt2) -@pytest.mark.parametrize('attn_impl', [ - 'torch', - pytest.param('flash', marks=pytest.mark.gpu), -]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) +@pytest.mark.parametrize( + 'attn_impl', + [ + 'torch', + pytest.param('flash', marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) composer_device = get_device(None) @@ -1432,116 +1714,148 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict): mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) mpt.eval() - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) first_input_ids_no_padding = composer_device.tensor_to_device( - first_input_ids_no_padding) + first_input_ids_no_padding, + ) first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() first_attention_mask_no_padding = composer_device.tensor_to_device( - first_attention_mask_no_padding) + first_attention_mask_no_padding, + ) # start with passing the first three tokens through (no padding) first_output_no_padding = mpt( first_input_ids_no_padding, - attention_mask=first_attention_mask_no_padding) + attention_mask=first_attention_mask_no_padding, + ) second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) second_input_ids_no_padding = composer_device.tensor_to_device( - second_input_ids_no_padding) + second_input_ids_no_padding, + ) second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() second_attention_mask_no_padding = composer_device.tensor_to_device( - second_attention_mask_no_padding) + second_attention_mask_no_padding, + ) # pass through the fourth token by itself, using the key-value cache (no padding) second_output_no_padding = mpt( second_input_ids_no_padding[:, -1].unsqueeze(-1), attention_mask=second_attention_mask_no_padding, - past_key_values=first_output_no_padding.past_key_values) + past_key_values=first_output_no_padding.past_key_values, + ) first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) first_input_ids_padding = composer_device.tensor_to_device( - first_input_ids_padding) + first_input_ids_padding, + ) first_attention_mask_padding = torch.tensor([[0, 1, 1, 1]]).bool() first_attention_mask_padding = composer_device.tensor_to_device( - first_attention_mask_padding) + first_attention_mask_padding, + ) # start with passing the first three tokens through (with left padding) - first_output_padding = mpt(first_input_ids_padding, - attention_mask=first_attention_mask_padding) + first_output_padding = mpt( + first_input_ids_padding, + attention_mask=first_attention_mask_padding, + ) - second_input_ids_padding = torch.tensor( - [[50256, 11274, 16390, 11, 11274]]) + second_input_ids_padding = torch.tensor([[ + 50256, + 11274, + 16390, + 11, + 11274, + ]]) second_input_ids_padding = composer_device.tensor_to_device( - second_input_ids_padding) + second_input_ids_padding, + ) second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() second_attention_mask_padding = composer_device.tensor_to_device( - second_attention_mask_padding) + second_attention_mask_padding, + ) # pass through the fourth token by itself, using the key-value cache (with left padding) second_output_padding = mpt( second_input_ids_padding[:, -1].unsqueeze(-1), attention_mask=second_attention_mask_padding, - past_key_values=first_output_padding.past_key_values) + past_key_values=first_output_padding.past_key_values, + ) # check that the outputs are the same with or without padding if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. + 'rope_impl' + ] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. torch.testing.assert_close( second_output_no_padding.logits, second_output_padding.logits[:, -1, :].unsqueeze(1), atol=1e-2, - rtol=1e-6) + rtol=1e-6, + ) else: torch.testing.assert_close( second_output_no_padding.logits, second_output_padding.logits[:, -1, :].unsqueeze(1), atol=1e-6, - rtol=1e-6) - - -@pytest.mark.parametrize('attn_impl', [ - 'torch', - pytest.param('flash', marks=pytest.mark.gpu), -]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) + rtol=1e-6, + ) + + +@pytest.mark.parametrize( + 'attn_impl', + [ + 'torch', + pytest.param('flash', marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_forward_with_cache( + attn_impl: str, + pos_emb_config: dict, + tie_word_embeddings: bool, +): # Test that model forward with and without the key-value cache produces the # same output. if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) composer_device = get_device(None) @@ -1569,13 +1883,15 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, mpt = composer_device.module_to_device(mpt) mpt.eval() - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): first_input_ids = torch.tensor([[11274, 16390, 11]]) first_input_ids = composer_device.tensor_to_device(first_input_ids) first_attention_mask = torch.tensor([[1, 1, 1]]).bool() first_attention_mask = composer_device.tensor_to_device( - first_attention_mask) + first_attention_mask, + ) # start with passing the first three tokens through first_output = mpt(first_input_ids, attention_mask=first_attention_mask) @@ -1584,23 +1900,33 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, assert len(first_output.past_key_values) == hf_config.n_layers assert all( len(past_key_value) == 2 - for past_key_value in first_output.past_key_values) + for past_key_value in first_output.past_key_values + ) if attn_impl == 'torch': - assert all(past_key_value[0].shape == (1, 4, 32, 3) - for past_key_value in first_output.past_key_values) - assert all(past_key_value[1].shape == (1, 4, 3, 32) - for past_key_value in first_output.past_key_values) + assert all( + past_key_value[0].shape == (1, 4, 32, 3) + for past_key_value in first_output.past_key_values + ) + assert all( + past_key_value[1].shape == (1, 4, 3, 32) + for past_key_value in first_output.past_key_values + ) else: - assert all(past_key_value[0].shape == (1, 3, 128) - for past_key_value in first_output.past_key_values) - assert all(past_key_value[1].shape == (1, 3, 128) - for past_key_value in first_output.past_key_values) + assert all( + past_key_value[0].shape == (1, 3, 128) + for past_key_value in first_output.past_key_values + ) + assert all( + past_key_value[1].shape == (1, 3, 128) + for past_key_value in first_output.past_key_values + ) second_input_ids = torch.tensor([[11274, 16390, 11, 11274]]) second_input_ids = composer_device.tensor_to_device(second_input_ids) second_attention_mask = torch.tensor([[1, 1, 1, 1]]).bool() second_attention_mask = composer_device.tensor_to_device( - second_attention_mask) + second_attention_mask, + ) # pass through the fourth token by itself, using the key-value cache second_output = mpt( @@ -1613,21 +1939,32 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, assert len(second_output.past_key_values) == hf_config.n_layers assert all( len(past_key_value) == 2 - for past_key_value in second_output.past_key_values) + for past_key_value in second_output.past_key_values + ) if attn_impl == 'torch': - assert all(past_key_value[0].shape == (1, 4, 32, 4) - for past_key_value in second_output.past_key_values) - assert all(past_key_value[1].shape == (1, 4, 4, 32) - for past_key_value in second_output.past_key_values) + assert all( + past_key_value[0].shape == (1, 4, 32, 4) + for past_key_value in second_output.past_key_values + ) + assert all( + past_key_value[1].shape == (1, 4, 4, 32) + for past_key_value in second_output.past_key_values + ) else: - assert all(past_key_value[0].shape == (1, 4, 128) - for past_key_value in second_output.past_key_values) - assert all(past_key_value[1].shape == (1, 4, 128) - for past_key_value in second_output.past_key_values) + assert all( + past_key_value[0].shape == (1, 4, 128) + for past_key_value in second_output.past_key_values + ) + assert all( + past_key_value[1].shape == (1, 4, 128) + for past_key_value in second_output.past_key_values + ) # pass through the first four tokens without the key-value cache - full_output = mpt(second_input_ids, - attention_mask=second_attention_mask) + full_output = mpt( + second_input_ids, + attention_mask=second_attention_mask, + ) # check that the output is the same whether using the key-value cache or not torch.testing.assert_close( @@ -1638,45 +1975,55 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict, ) -@pytest.mark.parametrize('attn_impl', [ - 'torch', - pytest.param('flash', marks=pytest.mark.gpu), -]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) +@pytest.mark.parametrize( + 'attn_impl', + [ + 'torch', + pytest.param('flash', marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict, - tie_word_embeddings: bool): +def test_generate_with_past_kv( + attn_impl: str, + pos_emb_config: dict, + tie_word_embeddings: bool, +): if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) composer_device = get_device(None) @@ -1707,24 +2054,35 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict, # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) no_padding_input_ids = composer_device.tensor_to_device( - no_padding_input_ids) + no_padding_input_ids, + ) no_padding_attention_mask = torch.tensor([[1, 1, 1]]) no_padding_attention_mask = composer_device.tensor_to_device( - no_padding_attention_mask) + no_padding_attention_mask, + ) - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): - with mock.patch.object(MPTForCausalLM, 'forward', - autospec=True) as forward_mocked: + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): + with mock.patch.object( + MPTForCausalLM, + 'forward', + autospec=True, + ) as forward_mocked: forward_mocked.return_value = CausalLMOutputWithPast( logits=composer_device.tensor_to_device( - torch.randn((1, 3, hf_config.vocab_size))), - past_key_values=[(torch.randn(1, 3, hf_config.d_model), - torch.randn(1, 3, hf_config.d_model)) - for _ in range(hf_config.n_layers)]) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - max_new_tokens=2) + torch.randn((1, 3, hf_config.vocab_size)), + ), + past_key_values=[( + torch.randn(1, 3, hf_config.d_model), + torch.randn(1, 3, hf_config.d_model), + ) for _ in range(hf_config.n_layers)], + ) + _ = mpt.generate( + input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + max_new_tokens=2, + ) assert forward_mocked.call_count == 2 _, _, kwargs = forward_mocked.mock_calls[0] @@ -1732,58 +2090,73 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict, _, _, kwargs = forward_mocked.mock_calls[1] assert kwargs['past_key_values'] is not None assert len(kwargs['past_key_values']) == hf_config.n_layers - assert kwargs['past_key_values'][0][0].shape == (1, 3, - hf_config.d_model) - - -@pytest.mark.parametrize('attn_impl', [ - 'torch', - pytest.param('flash', marks=pytest.mark.gpu), -]) -@pytest.mark.parametrize('generation_kwargs', [{ - 'max_new_tokens': 2, - 'num_beams': 4, - 'top_k': 5, - 'penalty_alpha': 0.4 -}]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) + assert kwargs['past_key_values'][0][0].shape == ( + 1, + 3, + hf_config.d_model, + ) + + +@pytest.mark.parametrize( + 'attn_impl', + [ + 'torch', + pytest.param('flash', marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + 'generation_kwargs', + [{ + 'max_new_tokens': 2, + 'num_beams': 4, + 'top_k': 5, + 'penalty_alpha': 0.4, + }], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_generation_kwargs_dont_crash(attn_impl: str, - generation_kwargs: Dict[str, Any], - pos_emb_config: dict, - tie_word_embeddings: bool): +def test_generation_kwargs_dont_crash( + attn_impl: str, + generation_kwargs: Dict[str, Any], + pos_emb_config: dict, + tie_word_embeddings: bool, +): if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) composer_device = get_device(None) if composer_device.name == 'gpu': @@ -1809,55 +2182,63 @@ def test_generation_kwargs_dont_crash(attn_impl: str, mpt = composer_device.module_to_device(mpt) mpt.eval() - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) no_padding_input_ids = composer_device.tensor_to_device( - no_padding_input_ids) + no_padding_input_ids, + ) no_padding_attention_mask = torch.tensor([[1, 1, 1]]) no_padding_attention_mask = composer_device.tensor_to_device( - no_padding_attention_mask) + no_padding_attention_mask, + ) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - **generation_kwargs) + _ = mpt.generate( + input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + **generation_kwargs, + ) if composer_device.name == 'gpu': reproducibility.configure_deterministic_mode() @pytest.mark.gpu -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_model_to(pos_emb_config: dict, tie_word_embeddings: bool): # test that moving the model to diff devices and dtypes in diff ways does not break the model if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip(f'dail implementation of rope requires flash attention 2.') hf_config = MPTConfig( @@ -1896,8 +2277,10 @@ def test_model_to(pos_emb_config: dict, tie_word_embeddings: bool): # verify the model still works if not (pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True): - _ = mpt(input_ids.to('cpu'), - attention_mask=attention_mask.to('cpu')) + _ = mpt( + input_ids.to('cpu'), + attention_mask=attention_mask.to('cpu'), + ) mpt = mpt.float() @@ -1919,60 +2302,76 @@ def test_alibi_vs_hf(): for seq_len in [1, 2, 8, 13, 64, 195, 256]: # hf bloom alibi bais alibi_bias_hf = build_alibi_tensor( - torch.ones(seq_len)[None, ...], n_heads, torch.float32) + torch.ones(seq_len)[None, ...], + n_heads, + torch.float32, + ) alibi_bias_hf = alibi_bias_hf - alibi_bias_hf.max( - dim=2, keepdim=True).values + dim=2, + keepdim=True, + ).values # mosaicml alibi bais - alibi_bias_m = build_alibi_bias(n_heads, - seq_len, - dtype=torch.float32) + alibi_bias_m = build_alibi_bias( + n_heads, + seq_len, + dtype=torch.float32, + ) alibi_bias_m = alibi_bias_m[0] torch.testing.assert_close(alibi_bias_hf, alibi_bias_m) -@pytest.mark.parametrize('attn_impl', [ - 'torch', - pytest.param('flash', marks=pytest.mark.gpu), - pytest.param('torch', marks=pytest.mark.gpu), -]) -@pytest.mark.parametrize('pos_emb_config', [{ - 'alibi': False, - 'rope': False -}, { - 'alibi': True, - 'rope': False -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, -}, { - 'alibi': False, - 'rope': True, - 'rope_theta': 10000, - 'rope_impl': 'hf', - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -}]) +@pytest.mark.parametrize( + 'attn_impl', + [ + 'torch', + pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('torch', marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + 'pos_emb_config', + [{ + 'alibi': False, + 'rope': False, + }, { + 'alibi': True, + 'rope': False, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + }, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, + }], +) def test_forward_with_output_attentions_and_output_hidden_states( - attn_impl: str, pos_emb_config: dict): + attn_impl: str, + pos_emb_config: dict, +): if pos_emb_config['alibi'] and not check_alibi_support(attn_impl): pytest.skip(f'flash attention below v2.4.2 does not support alibi.') if attn_impl == 'flash': pytest.skip(f'output_attentions only implemented with torch attention.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.', + ) composer_device = get_device(None) @@ -2002,8 +2401,9 @@ def test_forward_with_output_attentions_and_output_hidden_states( mpt = composer_device.module_to_device(mpt) mpt.eval() - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): input_ids = torch.tensor([[11274, 16390, 11]]) input_ids = composer_device.tensor_to_device(input_ids) attention_mask = torch.tensor([[1, 1, 1]]).bool() @@ -2024,10 +2424,12 @@ def test_forward_with_output_attentions_and_output_hidden_states( @pytest.mark.gpu @pytest.mark.parametrize('init_device', ['cpu', 'meta', 'mixed']) @pytest.mark.parametrize('world_size', [2]) -def test_hf_init(tmp_path: pathlib.Path, - init_device: str, - world_size: int, - batch_size: int = 1): +def test_hf_init( + tmp_path: pathlib.Path, + init_device: str, + world_size: int, + batch_size: int = 1, +): if not torch.cuda.device_count() >= world_size: pytest.skip(f'This test requires {world_size} GPUs.') @@ -2078,16 +2480,22 @@ def test_hf_init(tmp_path: pathlib.Path, # Load in a pretrained model with a given context with context: - model = AutoModelForCausalLM.from_pretrained(save_path, - trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + save_path, + trust_remote_code=True, + ) tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) - tokenizer = build_tokenizer(test_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + test_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) - optimizer = DecoupledAdamW(model.parameters(), - lr=1e-5, - betas=tuple([0.9, 0.99])) + optimizer = DecoupledAdamW( + model.parameters(), + lr=1e-5, + betas=(0.9, 0.99), + ) prepare_fsdp_module(model, optimizer, fsdp_config, precision, device, False) @@ -2127,14 +2535,16 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): resid_pdrop=0.2, attn_config={ 'attn_impl': 'flash', - 'attn_type': 'multiquery_attention' + 'attn_type': 'multiquery_attention', }, ) test_cfg.device = torch.cuda.current_device() tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) - tokenizer = build_tokenizer(test_cfg.tokenizer.name, - tokenizer_cfg.get('kwargs', {})) + tokenizer = build_tokenizer( + test_cfg.tokenizer.name, + tokenizer_cfg.get('kwargs', {}), + ) mpt = MPTForCausalLM(hf_config) @@ -2143,8 +2553,10 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2): model = model.to(test_cfg.device) batch = gen_random_batch(batch_size, test_cfg) - assert batch['input_ids'].shape == torch.Size( - [batch_size, test_cfg.max_seq_len]) + assert batch['input_ids'].shape == torch.Size([ + batch_size, + test_cfg.max_seq_len, + ]) model.train() diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py index 00c6a1c7a8..1c9b5ef9d4 100644 --- a/tests/models/test_mpt_gen.py +++ b/tests/models/test_mpt_gen.py @@ -14,8 +14,10 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase -from llmfoundry.models.mpt.modeling_mpt import (ComposerMPTCausalLM, - MPTForCausalLM) +from llmfoundry.models.mpt.modeling_mpt import ( + ComposerMPTCausalLM, + MPTForCausalLM, +) EOS_TOKEN_ID = 0 @@ -36,10 +38,18 @@ def forward( use_cache: Optional[bool] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ): - result = super().forward(input_ids, past_key_values, attention_mask, - sequence_id, labels, return_dict, - output_attentions, output_hidden_states, - use_cache, inputs_embeds) + result = super().forward( + input_ids, + past_key_values, + attention_mask, + sequence_id, + labels, + return_dict, + output_attentions, + output_hidden_states, + use_cache, + inputs_embeds, + ) # Modify the logits to select the next token. if dist.get_global_rank() == 0: # Rank 0 hits EOS immediately. @@ -55,13 +65,17 @@ def forward( @pytest.mark.parametrize('attn_impl', ['flash', 'torch']) @pytest.mark.parametrize('use_alibi', [True, False]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -@patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM', - new=MockMPTForCausalLM) -def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, - tie_word_embeddings: bool, - build_tiny_mpt: Callable[..., - ComposerMPTCausalLM], - mpt_tokenizer: PreTrainedTokenizerBase): +@patch( + 'llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM', + new=MockMPTForCausalLM, +) +def test_mpt_generate_multi_gpu( + attn_impl: str, + use_alibi: bool, + tie_word_embeddings: bool, + build_tiny_mpt: Callable[..., ComposerMPTCausalLM], + mpt_tokenizer: PreTrainedTokenizerBase, +): """Tests mpt generation with mutiple gpus. and generations of different lengths. @@ -73,7 +87,7 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, attn_config={ 'attn_impl': attn_impl, 'attn_uses_sequence_id': False, - 'alibi': use_alibi + 'alibi': use_alibi, }, ) model = device.module_to_device(model) @@ -83,21 +97,26 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, model.model = FSDP(model.model) with get_precision_context('amp_bf16'): - _ = model.generate(device.tensor_to_device( - mpt_tokenizer('hello', return_tensors='pt')['input_ids']), - max_new_tokens=3, - eos_token_id=EOS_TOKEN_ID, - use_cache=True, - synced_gpus=True) + _ = model.generate( + device.tensor_to_device( + mpt_tokenizer('hello', return_tensors='pt')['input_ids'], + ), + max_new_tokens=3, + eos_token_id=EOS_TOKEN_ID, + use_cache=True, + synced_gpus=True, + ) @pytest.mark.gpu @pytest.mark.parametrize('attn_impl', ['flash', 'torch']) @pytest.mark.parametrize('use_alibi', [True, False]) -def test_mpt_generate_callback(attn_impl: str, use_alibi: bool, - build_tiny_mpt: Callable[..., - ComposerMPTCausalLM], - tiny_ft_dataloader: DataLoader): +def test_mpt_generate_callback( + attn_impl: str, + use_alibi: bool, + build_tiny_mpt: Callable[..., ComposerMPTCausalLM], + tiny_ft_dataloader: DataLoader, +): device = get_device('gpu') # build mpt model @@ -106,7 +125,7 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool, attn_config={ 'attn_impl': attn_impl, 'attn_uses_sequence_id': False, - 'alibi': use_alibi + 'alibi': use_alibi, }, ) model = device.module_to_device(model) @@ -177,11 +196,13 @@ def test_gen_mpt_moe( model.eval() - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context( + 'amp_bf16' if composer_device.name == 'gpu' else 'fp32', + ): _ = model.generate( composer_device.tensor_to_device( - mpt_tokenizer('hello', return_tensors='pt')['input_ids']), + mpt_tokenizer('hello', return_tensors='pt')['input_ids'], + ), max_new_tokens=10, ) @@ -190,9 +211,11 @@ def test_gen_mpt_moe( @pytest.mark.parametrize('attn_impl', ['flash', 'torch']) @pytest.mark.parametrize('use_alibi', [True, False]) def test_mpt_generate_callback_not_tied( - use_alibi: bool, attn_impl: str, - build_tiny_mpt: Callable[..., ComposerMPTCausalLM], - tiny_ft_dataloader: DataLoader): + use_alibi: bool, + attn_impl: str, + build_tiny_mpt: Callable[..., ComposerMPTCausalLM], + tiny_ft_dataloader: DataLoader, +): device = get_device('gpu') # build mpt model diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index becd3c773f..95732cfd8f 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -21,7 +21,7 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int): dtype=torch.int64, ), 'attention_mask': - torch.ones(size=(batch_size, max_seq_len), dtype=torch.bool) + torch.ones(size=(batch_size, max_seq_len), dtype=torch.bool), } return batch diff --git a/tests/models/test_rmsnorm_triton_vs_eager.py b/tests/models/test_rmsnorm_triton_vs_eager.py index 7169c5d926..c8f0a2e07f 100644 --- a/tests/models/test_rmsnorm_triton_vs_eager.py +++ b/tests/models/test_rmsnorm_triton_vs_eager.py @@ -13,12 +13,15 @@ @pytest.mark.gpu @pytest.mark.parametrize('normalized_shape', [32, 128, 4096]) -def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]], - device: str = 'cuda'): +def test_rmsnorm_triton_vs_eager( + normalized_shape: Union[int, List[int]], + device: str = 'cuda', +): # Compare Triton and PyTorch Eager implementations of RMSNorm if not is_flash_v2_installed(): pytest.skip( - 'triton implementation of rmsnorm requires flash attention 2.') + 'triton implementation of rmsnorm requires flash attention 2.', + ) batch_size = 2 @@ -38,7 +41,7 @@ def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]], if isinstance(normalized_shape, int): input_shape = [batch_size, normalized_shape] else: - input_shape = tuple([batch_size, *normalized_shape]) + input_shape = (batch_size, *normalized_shape) x0 = torch.randn(size=input_shape, device=device) x1 = x0.clone().detach() diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index b9ab90357a..6a41e64f48 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -8,14 +8,17 @@ from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.layers.layer_builders import build_attention_layer -from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info, - gen_rotary_embedding) +from llmfoundry.models.mpt.modeling_mpt import ( + gen_flash_attn_padding_info, + gen_rotary_embedding, +) @pytest.mark.gpu @pytest.mark.parametrize( 'attn_type', - ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) + ['multihead_attention', 'multiquery_attention', 'grouped_query_attention'], +) @pytest.mark.parametrize('seq_len', [1, 233, 2048]) def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): # compare rope rotations for the dail vs hf implementations @@ -38,13 +41,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): attn0 = build_attention_layer( name=attn_type, - attn_kwargs=om.to_container( - cfg), # type: ignore (to_container return broad type) + attn_kwargs=om. + to_container(cfg), # type: ignore (to_container return broad type) ).to(device) attn1 = build_attention_layer( name=attn_type, - attn_kwargs=om.to_container( - cfg), # type: ignore (to_container return broad type) + attn_kwargs=om. + to_container(cfg), # type: ignore (to_container return broad type) ).to(device) attn1.load_state_dict(attn0.state_dict()) @@ -62,7 +65,7 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): 'type': 'original', 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, - } + }, } hf_rope_config = { 'rope_theta': 10000, @@ -70,7 +73,7 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, - } + }, } dail_rope = gen_rotary_embedding( @@ -79,7 +82,8 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): rope_theta=dail_rope_config['rope_theta'], rope_dail_config=dail_rope_config['rope_dail_config'], rope_hf_config={}, - max_seq_len=seq_len).to('cuda') + max_seq_len=seq_len, + ).to('cuda') dail_rope_w_meta_info = { 'impl': 'dail', 'rotary_emb': dail_rope, @@ -93,7 +97,8 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): rope_theta=hf_rope_config['rope_theta'], rope_dail_config={}, rope_hf_config=hf_rope_config['rope_hf_config'], - max_seq_len=seq_len).to('cuda') + max_seq_len=seq_len, + ).to('cuda') pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') # adjust the position indices to account for padding tokens pos = torch.clamp( @@ -107,25 +112,39 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): 'seq_len': seq_len, } - y0, _, _ = attn0(x0, - past_key_value=None, - attn_bias=None, - attention_mask=attention_mask, - rotary_emb_w_meta_info=dail_rope_w_meta_info, - is_causal=True, - flash_attn_padding_info=gen_flash_attn_padding_info( - batch_size, seq_len, 0, torch.device(device), None, - attention_mask)) - - y1, _, _ = attn1(x1, - past_key_value=None, - attn_bias=None, - attention_mask=attention_mask, - rotary_emb_w_meta_info=hf_rope_w_meta_info, - is_causal=True, - flash_attn_padding_info=gen_flash_attn_padding_info( - batch_size, seq_len, 0, torch.device(device), None, - attention_mask)) + y0, _, _ = attn0( + x0, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + rotary_emb_w_meta_info=dail_rope_w_meta_info, + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, + seq_len, + 0, + torch.device(device), + None, + attention_mask, + ), + ) + + y1, _, _ = attn1( + x1, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + rotary_emb_w_meta_info=hf_rope_w_meta_info, + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, + seq_len, + 0, + torch.device(device), + None, + attention_mask, + ), + ) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -138,7 +157,7 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): torch.testing.assert_close(y0, y1, rtol=1e-2, atol=1e-2) - torch_name_param_map = {n: p for n, p in attn1.named_parameters()} + torch_name_param_map = dict(attn1.named_parameters()) for n, p in attn0.named_parameters(): tp = torch_name_param_map[n] assert p.grad is not None diff --git a/tests/models/utils/test_param_init_fns.py b/tests/models/utils/test_param_init_fns.py index 0efc245602..de818304a6 100644 --- a/tests/models/utils/test_param_init_fns.py +++ b/tests/models/utils/test_param_init_fns.py @@ -92,17 +92,27 @@ def init_fn_(weight: torch.Tensor): assert (p == expected_value).all() -@pytest.mark.parametrize('module', [ - nn.Linear(8, 16), - nn.Embedding(8, 16), - pytest.param(nn.LayerNorm(8), - marks=pytest.mark.xfail( - reason='LayerNorm is skipped by init_fn_', strict=True)), - pytest.param(nn.Conv2d(8, 16, 3), - marks=pytest.mark.xfail( - reason='generic_param_init_fn_ does not init Conv layers', - strict=True)), -]) +@pytest.mark.parametrize( + 'module', + [ + nn.Linear(8, 16), + nn.Embedding(8, 16), + pytest.param( + nn.LayerNorm(8), + marks=pytest.mark.xfail( + reason='LayerNorm is skipped by init_fn_', + strict=True, + ), + ), + pytest.param( + nn.Conv2d(8, 16, 3), + marks=pytest.mark.xfail( + reason='generic_param_init_fn_ does not init Conv layers', + strict=True, + ), + ), + ], +) def test_all_params_init(module: torch.nn.Module): fill_val = torch.finfo(torch.float16).max @@ -114,8 +124,9 @@ def max_fill_init_(weight: torch.Tensor): cfg = om.create({ 'n_layers': 2, }) - module.apply(partial(generic_param_init_fn_, init_fn_=max_fill_init_, - **cfg)) + module.apply( + partial(generic_param_init_fn_, init_fn_=max_fill_init_, **cfg), + ) for n, p in module.named_parameters(): if n == 'bias': assert (p == 0).all() @@ -123,11 +134,18 @@ def max_fill_init_(weight: torch.Tensor): assert (p == fill_val).all() -@pytest.mark.parametrize('emb_init_cfg', [ - None, ('emb_init_std', 5), ('emb_init_std', 0), ('emb_init_uniform_lim', 2), - ('emb_init_uniform_lim', [-1, 4]), ('emb_init_uniform_lim', 0), - ('emb_init_uniform_lim', [1, 1]) -]) +@pytest.mark.parametrize( + 'emb_init_cfg', + [ + None, + ('emb_init_std', 5), + ('emb_init_std', 0), + ('emb_init_uniform_lim', 2), + ('emb_init_uniform_lim', [-1, 4]), + ('emb_init_uniform_lim', 0), + ('emb_init_uniform_lim', [1, 1]), + ], +) def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): cfg: Dict[str, Union[int, List[int]]] = { 'vocab_size': 64, @@ -142,14 +160,26 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): model = nn.Sequential( OrderedDict([ ('emb', nn.Embedding(dict_cfg.vocab_size, dict_cfg.in_features)), - ('fc1', - nn.Linear(dict_cfg.in_features, dict_cfg.out_features, bias=True)), + ( + 'fc1', + nn.Linear( + dict_cfg.in_features, + dict_cfg.out_features, + bias=True, + ), + ), ('ln1', nn.LayerNorm(dict_cfg.out_features)), ('act1', nn.ReLU()), - ('fc2', - nn.Linear(dict_cfg.out_features, dict_cfg.out_features, - bias=True)), - ])) + ( + 'fc2', + nn.Linear( + dict_cfg.out_features, + dict_cfg.out_features, + bias=True, + ), + ), + ]), + ) model.apply(partial(param_init_fns.get('kaiming_normal_'), **dict_cfg)) @@ -165,6 +195,7 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): assert (model.emb.weight == 0).all() elif isinstance(emb_init_uniform_lim, Sequence): assert len(emb_init_uniform_lim) <= 2 - if len(emb_init_uniform_lim - ) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]: + if len( + emb_init_uniform_lim, + ) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]: assert (model.emb.weight == emb_init_uniform_lim[0]).all() diff --git a/tests/optim/test_scheduler.py b/tests/optim/test_scheduler.py index 811088bd62..602a4ef8c7 100644 --- a/tests/optim/test_scheduler.py +++ b/tests/optim/test_scheduler.py @@ -34,80 +34,114 @@ def dummy_schedulers_state(request: pytest.FixtureRequest): return state -@pytest.mark.parametrize('scheduler,ssr,test_times,expected_lrs', [ - pytest.param( - InverseSquareRootWithWarmupScheduler(t_warmup='10ba', - t_scale='10ba', - t_cooldown='0ba', - alpha_f_decay=0, - alpha_f_cooldown=0), 1.0, - ['0ba', '5ba', '10ba', '40ba', '90ba', '100ba'], - [0.0, 0.5, 1.0, 0.5, 0.33333, 0.31623]), - pytest.param( - InverseSquareRootWithWarmupScheduler(t_warmup='20ba', - t_scale='2ba', - t_cooldown='10ba', - alpha_f_decay=0.4, - alpha_f_cooldown=0.1), 1.0, - ['0ba', '10ba', '20ba', '36ba', '90ba', '95ba', '100ba'], - [0.0, 0.5, 1.0, 0.6, 0.5, 0.3, 0.1]), -]) -def test_scheduler_init(scheduler: ComposerScheduler, ssr: float, - test_times: List[str], expected_lrs: List[float], - dummy_schedulers_state: State): +@pytest.mark.parametrize( + 'scheduler,ssr,test_times,expected_lrs', + [ + pytest.param( + InverseSquareRootWithWarmupScheduler( + t_warmup='10ba', + t_scale='10ba', + t_cooldown='0ba', + alpha_f_decay=0, + alpha_f_cooldown=0, + ), + 1.0, + ['0ba', '5ba', '10ba', '40ba', '90ba', '100ba'], + [0.0, 0.5, 1.0, 0.5, 0.33333, 0.31623], + ), + pytest.param( + InverseSquareRootWithWarmupScheduler( + t_warmup='20ba', + t_scale='2ba', + t_cooldown='10ba', + alpha_f_decay=0.4, + alpha_f_cooldown=0.1, + ), + 1.0, + ['0ba', '10ba', '20ba', '36ba', '90ba', '95ba', '100ba'], + [0.0, 0.5, 1.0, 0.6, 0.5, 0.3, 0.1], + ), + ], +) +def test_scheduler_init( + scheduler: ComposerScheduler, + ssr: float, + test_times: List[str], + expected_lrs: List[float], + dummy_schedulers_state: State, +): state = dummy_schedulers_state assert state.dataloader_len is not None assert state.max_duration is not None - state.max_duration = Time(value=int(state.max_duration.value * ssr), - unit=state.max_duration.unit) + state.max_duration = Time( + value=int(state.max_duration.value * ssr), + unit=state.max_duration.unit, + ) for test_time, expected_lr in zip(test_times, expected_lrs): parsed_time = Time.from_timestring(test_time) assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH] state.timestamp = state.timestamp.copy( batch=parsed_time, epoch=Time( - int(parsed_time) // int(state.dataloader_len), TimeUnit.EPOCH), + int(parsed_time) // int(state.dataloader_len), + TimeUnit.EPOCH, + ), ) lr = scheduler(state, ssr) assert lr == pytest.approx(expected_lr, abs=1e-3) -@pytest.mark.parametrize('state_unit,warmup_unit,scale_unit,cooldown_unit', [ - ['ep', 'ba', 'ba', 'ba'], - ['ba', 'ep', 'ep', 'ep'], - ['ep', 'ep', 'ba', 'ep'], -]) -def test_scheduler_units_match_error(state_unit: str, warmup_unit: str, - scale_unit: str, cooldown_unit: str, - dummy_schedulers_state: State): +@pytest.mark.parametrize( + 'state_unit,warmup_unit,scale_unit,cooldown_unit', + [ + ['ep', 'ba', 'ba', 'ba'], + ['ba', 'ep', 'ep', 'ep'], + ['ep', 'ep', 'ba', 'ep'], + ], +) +def test_scheduler_units_match_error( + state_unit: str, + warmup_unit: str, + scale_unit: str, + cooldown_unit: str, + dummy_schedulers_state: State, +): state = dummy_schedulers_state state.max_duration = f'1{state_unit}' scheduler = InverseSquareRootWithWarmupScheduler( t_warmup=f'10{warmup_unit}', t_scale=f'10{scale_unit}', - t_cooldown=f'10{cooldown_unit}') + t_cooldown=f'10{cooldown_unit}', + ) with pytest.raises(ValueError, match='must match'): _ = scheduler(state, 1.0) -@pytest.mark.parametrize('warmup_unit,scale_unit,cooldown_unit', [ - ['dur', 'ba', 'ba'], - ['ba', 'dur', 'ba'], - ['ba', 'ba', 'dur'], -]) +@pytest.mark.parametrize( + 'warmup_unit,scale_unit,cooldown_unit', + [ + ['dur', 'ba', 'ba'], + ['ba', 'dur', 'ba'], + ['ba', 'ba', 'dur'], + ], +) def test_unit_dur_error(warmup_unit: str, scale_unit: str, cooldown_unit: str): with pytest.raises(ValueError, match='cannot be in units of "dur".'): - _ = InverseSquareRootWithWarmupScheduler(t_warmup=f'1{warmup_unit}', - t_scale=f'1{scale_unit}', - t_cooldown=f'1{cooldown_unit}') + _ = InverseSquareRootWithWarmupScheduler( + t_warmup=f'1{warmup_unit}', + t_scale=f'1{scale_unit}', + t_cooldown=f'1{cooldown_unit}', + ) def test_alpha_f_error(): with pytest.raises(ValueError, match='alpha_f_decay >= alpha_f_cooldown.'): - _ = InverseSquareRootWithWarmupScheduler(t_warmup='10ba', - t_scale='10ba', - t_cooldown='10ba', - alpha_f_decay=0.0, - alpha_f_cooldown=0.1) + _ = InverseSquareRootWithWarmupScheduler( + t_warmup='10ba', + t_scale='10ba', + t_cooldown='10ba', + alpha_f_decay=0.0, + alpha_f_cooldown=0.1, + ) diff --git a/tests/test_registry.py b/tests/test_registry.py index d7a1fc7dfe..0db9f832b7 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -47,10 +47,12 @@ def test_expected_registries_exist(): def test_registry_create(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(catalogue, 'Registry', {}) - new_registry = registry_utils.create_registry('llmfoundry', - 'test_registry', - generic_type=str, - entry_points=False) + new_registry = registry_utils.create_registry( + 'llmfoundry', + 'test_registry', + generic_type=str, + entry_points=False, + ) assert new_registry.namespace == ('llmfoundry', 'test_registry') assert isinstance(new_registry, registry_utils.TypedRegistry) @@ -58,10 +60,12 @@ def test_registry_create(monkeypatch: pytest.MonkeyPatch): def test_registry_typing(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(catalogue, 'Registry', {}) - new_registry = registry_utils.create_registry('llmfoundry', - 'test_registry', - generic_type=str, - entry_points=False) + new_registry = registry_utils.create_registry( + 'llmfoundry', + 'test_registry', + generic_type=str, + entry_points=False, + ) new_registry.register('test_name', func='test') # This would fail type checking without the type ignore @@ -72,10 +76,12 @@ def test_registry_typing(monkeypatch: pytest.MonkeyPatch): def test_registry_add(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(catalogue, 'Registry', {}) - new_registry = registry_utils.create_registry('llmfoundry', - 'test_registry', - generic_type=str, - entry_points=False) + new_registry = registry_utils.create_registry( + 'llmfoundry', + 'test_registry', + generic_type=str, + entry_points=False, + ) new_registry.register('test_name', func='test') assert new_registry.get('test_name') == 'test' @@ -83,10 +89,12 @@ def test_registry_add(monkeypatch: pytest.MonkeyPatch): def test_registry_overwrite(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(catalogue, 'Registry', {}) - new_registry = registry_utils.create_registry('llmfoundry', - 'test_registry', - generic_type=str, - entry_points=False) + new_registry = registry_utils.create_registry( + 'llmfoundry', + 'test_registry', + generic_type=str, + entry_points=False, + ) new_registry.register('test_name', func='test') new_registry.register('test_name', func='test2') @@ -121,20 +129,30 @@ def test_registry_entrypoint(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(catalogue, 'Registry', {}) monkeypatch.setattr( - importlib.metadata, 'entry_points', lambda: { + importlib.metadata, + 'entry_points', + lambda: { 'llmfoundry_test_registry': [ - EntryPoint(name='test_entry', - value='composer.loggers:InMemoryLogger', - group='llmfoundry_test_registry') - ] - }) - - monkeypatch.setattr(catalogue, 'AVAILABLE_ENTRY_POINTS', - importlib.metadata.entry_points()) - new_registry = registry_utils.create_registry('llmfoundry', - 'test_registry', - generic_type=str, - entry_points=True) + EntryPoint( + name='test_entry', + value='composer.loggers:InMemoryLogger', + group='llmfoundry_test_registry', + ), + ], + }, + ) + + monkeypatch.setattr( + catalogue, + 'AVAILABLE_ENTRY_POINTS', + importlib.metadata.entry_points(), + ) + new_registry = registry_utils.create_registry( + 'llmfoundry', + 'test_registry', + generic_type=str, + entry_points=True, + ) assert new_registry.get('test_entry') == InMemoryLogger @@ -146,7 +164,8 @@ def test_registry_builder(monkeypatch: pytest.MonkeyPatch): 'test_registry', entry_points=False, generic_type=Union[Type[LoggerDestination], - Callable[..., LoggerDestination]]) + Callable[..., LoggerDestination]], + ) class TestLoggerDestination(LoggerDestination): pass @@ -157,16 +176,20 @@ class TestLoggerDestination(LoggerDestination): valid_class = registry_utils.construct_from_registry( 'test_destination', new_registry, - pre_validation_function=TestLoggerDestination) + pre_validation_function=TestLoggerDestination, + ) assert isinstance(valid_class, TestLoggerDestination) # Invalid, class validation - with pytest.raises(ValueError, - match='Expected test_destination to be of type'): + with pytest.raises( + ValueError, + match='Expected test_destination to be of type', + ): registry_utils.construct_from_registry( 'test_destination', new_registry, - pre_validation_function=InMemoryLogger) + pre_validation_function=InMemoryLogger, + ) # Invalid, function pre-validation with pytest.raises(ValueError, match='Invalid'): @@ -177,7 +200,8 @@ def pre_validation_function(x: Any): registry_utils.construct_from_registry( 'test_destination', new_registry, - pre_validation_function=pre_validation_function) + pre_validation_function=pre_validation_function, + ) # Invalid, function post-validation with pytest.raises(ValueError, match='Invalid'): @@ -188,27 +212,38 @@ def post_validation_function(x: Any): registry_utils.construct_from_registry( 'test_destination', new_registry, - post_validation_function=post_validation_function) + post_validation_function=post_validation_function, + ) # Invalid, not a class or function new_registry.register('non_callable', func=1) # type: ignore - with pytest.raises(ValueError, - match='Expected non_callable to be a class or function'): + with pytest.raises( + ValueError, + match='Expected non_callable to be a class or function', + ): registry_utils.construct_from_registry('non_callable', new_registry) # Valid, partial function - new_registry.register('partial_func', - func=lambda x, y: x * y) # type: ignore - partial_func = registry_utils.construct_from_registry('partial_func', - new_registry, - partial_function=True, - kwargs={'x': 2}) + new_registry.register( + 'partial_func', + func=lambda x, + y: x * y, + ) # type: ignore + partial_func = registry_utils.construct_from_registry( + 'partial_func', + new_registry, + partial_function=True, + kwargs={'x': 2}, + ) assert partial_func(y=3) == 6 # Valid, builder function new_registry.register('builder_func', func=lambda: TestLoggerDestination()) valid_built_class = registry_utils.construct_from_registry( - 'builder_func', new_registry, partial_function=False) + 'builder_func', + new_registry, + partial_function=False, + ) assert isinstance(valid_built_class, TestLoggerDestination) assert os.environ['TEST_ENVIRON_REGISTRY_KEY'] == 'test' diff --git a/tests/tokenizers/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py index bb936db617..af18c73927 100644 --- a/tests/tokenizers/test_tiktoken.py +++ b/tests/tokenizers/test_tiktoken.py @@ -7,8 +7,10 @@ import pytest import transformers -from llmfoundry.tokenizers.tiktoken import (TiktokenTokenizerWrapper, - bytes_to_unicode) +from llmfoundry.tokenizers.tiktoken import ( + TiktokenTokenizerWrapper, + bytes_to_unicode, +) from tests.a_scripts.inference.test_convert_composer_to_hf import \ check_hf_tokenizer_equivalence from tests.horrible_strings import HORRIBLE_STRINGS @@ -17,16 +19,22 @@ from tiktoken.core import Encoding TEST_STRINGS = [ - 'Hello world!', 'def hello_world(input: str):\n print(input)', + 'Hello world!', + 'def hello_world(input: str):\n print(input)', '0000000000000000000000000000', - '19234324 asas sf 119aASDFM AW3RAW-AF;;9900', '\n\n\n\nhello\n\t', + '19234324 asas sf 119aASDFM AW3RAW-AF;;9900', + '\n\n\n\nhello\n\t', ' hello\n\t\\\\ goodbye!?*#&@!) ', 'This is just a normal sentence. And here is another one!', - 'hello<|endoftext|>world', 'hello <|endoftext|> world', - 'hello <|endoftext|>', 'hello <|endoftext|> ', '<|endoftext}>', - '<|endoftext}> ', ' <|endoftext|>', + 'hello<|endoftext|>world', + 'hello <|endoftext|> world', + 'hello <|endoftext|>', + 'hello <|endoftext|> ', + '<|endoftext}>', + '<|endoftext}> ', + ' <|endoftext|>', '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>', - '<|endoftext|> <|endoftext|> <|endoftext|> <|endoftext|>' + '<|endoftext|> <|endoftext|> <|endoftext|> <|endoftext|>', ] TEST_STRINGS += HORRIBLE_STRINGS @@ -46,25 +54,25 @@ 'content': 'Please summarize the goals in this text:\n\nGoing outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.', 'role': - 'user' + 'user', }, { 'content': 'You should go outside and touch grass.', - 'role': 'assistant' + 'role': 'assistant', }], [{ 'content': 'You are a honest and helpful AI language model. Tell the user the truth, the whole truth, and nothing but the truth.', 'role': - 'system' + 'system', }, { 'content': 'Please summarize the goals in this text:\n\nGoing outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.', 'role': - 'user' + 'user', }, { 'content': 'You should go outside and touch grass.', - 'role': 'assistant' - }] + 'role': 'assistant', + }], ] MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT = [ @@ -73,14 +81,15 @@ Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|> <|im_start|>assistant -You should go outside and touch grass.<|im_end|>""", """<|im_start|>system +You should go outside and touch grass.<|im_end|>""", + """<|im_start|>system You are a honest and helpful AI language model. Tell the user the truth, the whole truth, and nothing but the truth.<|im_end|> <|im_start|>user Please summarize the goals in this text: Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|> <|im_start|>assistant -You should go outside and touch grass.<|im_end|>""" +You should go outside and touch grass.<|im_end|>""", ] MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT = [ @@ -91,27 +100,28 @@ Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|> <|im_start|>assistant -You should go outside and touch grass.<|im_end|>""", """<|im_start|>system +You should go outside and touch grass.<|im_end|>""", + """<|im_start|>system You are a honest and helpful AI language model. Tell the user the truth, the whole truth, and nothing but the truth.<|im_end|> <|im_start|>user Please summarize the goals in this text: Going outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.<|im_end|> <|im_start|>assistant -You should go outside and touch grass.<|im_end|>""" +You should go outside and touch grass.<|im_end|>""", ] MULTI_TURN_GENERATE_CHAT_ML = [[{ 'content': 'Please summarize the goals in this text:\n\nGoing outside has benefits include reducing stress and triggering the relaxation response, which can help us not only feel better mentally, but even heal faster from physical ailments.', 'role': - 'user' + 'user', }, { 'content': 'You should go outside and touch grass.', - 'role': 'assistant' + 'role': 'assistant', }, { 'content': 'What else can I do?', - 'role': 'user' + 'role': 'user', }]] MULTI_TURN_GENERATE_STRING = [ @@ -126,7 +136,7 @@ <|im_start|>user What else can I do?<|im_end|> <|im_start|>assistant -""" +""", ] @@ -148,7 +158,8 @@ def get_tokenizers_for_testing( add_bos_token=add_bos_token, add_eos_token=add_eos_token, use_default_system_prompt=use_default_system_prompt, - additional_special_tokens=additional_special_tokens) + additional_special_tokens=additional_special_tokens, + ) if model_name is not None: original_tokenizer = tiktoken.encoding_for_model(model_name) else: @@ -160,23 +171,35 @@ def get_tokenizers_for_testing( # Save and load wrapped_tokenizer.save_pretrained(tmp_path) reloaded_wrapped_tokenizer = transformers.AutoTokenizer.from_pretrained( - tmp_path, trust_remote_code=True) + tmp_path, + trust_remote_code=True, + ) return wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_simple(model_name: Optional[str], - encoding_name: Optional[str], tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_simple( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( - model_name, encoding_name, tmp_path) + model_name, + encoding_name, + tmp_path, + ) # Simple tokenization test for string in TEST_STRINGS: wrapped_output = wrapped_tokenizer(string) - original_output = original_tokenizer.encode(string, - allowed_special='all') + original_output = original_tokenizer.encode( + string, + allowed_special='all', + ) reloaded_wrapped_output = reloaded_wrapped_tokenizer(string) assert wrapped_output['input_ids'] == original_output @@ -184,93 +207,146 @@ def test_tiktoken_simple(model_name: Optional[str], assert reloaded_wrapped_output == wrapped_output -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_tokenize_with_ids(model_name: Optional[str], - encoding_name: Optional[str], - tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_tokenize_with_ids( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( - model_name, encoding_name, tmp_path) + model_name, + encoding_name, + tmp_path, + ) for string in TEST_STRINGS: wrapped_output = wrapped_tokenizer.tokenize(string) - original_output = original_tokenizer.encode(string, - allowed_special='all') + original_output = original_tokenizer.encode( + string, + allowed_special='all', + ) reloaded_wrapped_output = reloaded_wrapped_tokenizer.tokenize(string) - assert all([isinstance(t, str) for t in wrapped_output]) + assert all(isinstance(t, str) for t in wrapped_output) assert len(wrapped_output) == len(original_output) assert wrapped_output == reloaded_wrapped_output redone_token_ids = wrapped_tokenizer.convert_tokens_to_ids( - wrapped_output) + wrapped_output, + ) assert redone_token_ids == original_output assert wrapped_tokenizer.convert_ids_to_tokens( - redone_token_ids) == wrapped_output + redone_token_ids, + ) == wrapped_output -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_roundtrip(model_name: Optional[str], - encoding_name: Optional[str], - tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_roundtrip( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( - model_name, encoding_name, tmp_path) + model_name, + encoding_name, + tmp_path, + ) for string in TEST_STRINGS: wrapped_output = wrapped_tokenizer.decode( - wrapped_tokenizer(string)['input_ids']) + wrapped_tokenizer(string)['input_ids'], + ) original_output = original_tokenizer.decode( - original_tokenizer.encode(string, allowed_special='all')) + original_tokenizer.encode(string, allowed_special='all'), + ) reloaded_wrapped_output = reloaded_wrapped_tokenizer.decode( - reloaded_wrapped_tokenizer(string)['input_ids']) + reloaded_wrapped_tokenizer(string)['input_ids'], + ) assert wrapped_output == string assert original_output == string assert reloaded_wrapped_output == string -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_batched(model_name: Optional[str], - encoding_name: Optional[str], tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_batched( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( - model_name, encoding_name, tmp_path) - - wrapped_output = wrapped_tokenizer( - ['Hello world!', 'Hello world but longer!']) - original_output = original_tokenizer.encode_batch( - ['Hello world!', 'Hello world but longer!']) - reloaded_wrapped_output = reloaded_wrapped_tokenizer( - ['Hello world!', 'Hello world but longer!']) + model_name, + encoding_name, + tmp_path, + ) + + wrapped_output = wrapped_tokenizer([ + 'Hello world!', + 'Hello world but longer!', + ]) + original_output = original_tokenizer.encode_batch([ + 'Hello world!', + 'Hello world but longer!', + ]) + reloaded_wrapped_output = reloaded_wrapped_tokenizer([ + 'Hello world!', + 'Hello world but longer!', + ]) assert wrapped_output['input_ids'] == original_output assert set(wrapped_output.keys()) == {'input_ids', 'attention_mask'} assert reloaded_wrapped_output == wrapped_output assert wrapped_tokenizer.batch_decode( - wrapped_output['input_ids']) == original_tokenizer.decode_batch( - original_output) + wrapped_output['input_ids'], + ) == original_tokenizer.decode_batch(original_output) assert reloaded_wrapped_tokenizer.batch_decode( - reloaded_wrapped_output['input_ids'] + reloaded_wrapped_output['input_ids'], ) == original_tokenizer.decode_batch(original_output) -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_padding(model_name: Optional[str], - encoding_name: Optional[str], tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_padding( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( - model_name, encoding_name, tmp_path) + model_name, + encoding_name, + tmp_path, + ) wrapped_tokenizer.pad_token_id = wrapped_tokenizer.eos_token_id reloaded_wrapped_tokenizer.pad_token_id = reloaded_wrapped_tokenizer.eos_token_id - wrapped_output = wrapped_tokenizer( - ['Hello world!', 'Hello world but longer!'], padding=True) - original_output = original_tokenizer.encode_batch( - ['Hello world!', 'Hello world but longer!']) - reloaded_wrapped_output = reloaded_wrapped_tokenizer( - ['Hello world!', 'Hello world but longer!'], padding=True) - for wrapped1, attn_mask, original1 in zip(wrapped_output['input_ids'], - wrapped_output['attention_mask'], - original_output): + wrapped_output = wrapped_tokenizer([ + 'Hello world!', + 'Hello world but longer!', + ], + padding=True) + original_output = original_tokenizer.encode_batch([ + 'Hello world!', + 'Hello world but longer!', + ]) + reloaded_wrapped_output = reloaded_wrapped_tokenizer([ + 'Hello world!', + 'Hello world but longer!', + ], + padding=True) + for wrapped1, attn_mask, original1 in zip( + wrapped_output['input_ids'], + wrapped_output['attention_mask'], + original_output, + ): original_length = len(original1) assert wrapped1[:original_length] == original1 assert sum(attn_mask) == original_length @@ -279,12 +355,20 @@ def test_tiktoken_padding(model_name: Optional[str], assert reloaded_wrapped_output == wrapped_output -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_vocab(model_name: Optional[str], encoding_name: Optional[str], - tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_vocab( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( - model_name, encoding_name, tmp_path) + model_name, + encoding_name, + tmp_path, + ) wrapped_vocab = wrapped_tokenizer.get_vocab() reloaded_wrapped_vocab = reloaded_wrapped_tokenizer.get_vocab() @@ -296,41 +380,56 @@ def test_tiktoken_vocab(model_name: Optional[str], encoding_name: Optional[str], continue expected_decoding = ''.join([ - bytes_to_unicode()[ord(char)] - for char in original_tokenizer.decode_single_token_bytes( - value).decode('latin-1') + bytes_to_unicode()[ord(char)] for char in original_tokenizer. + decode_single_token_bytes(value).decode('latin-1') ]) assert expected_decoding == key -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_save_from_pretrained(model_name: Optional[str], - encoding_name: Optional[str], - tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_save_from_pretrained( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): wrapped_tokenizer, reloaded_wrapped_tokenizer, _ = get_tokenizers_for_testing( - model_name, encoding_name, tmp_path) - check_hf_tokenizer_equivalence(wrapped_tokenizer, - reloaded_wrapped_tokenizer) - - -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_tiktoken_encode_plus(model_name: Optional[str], - encoding_name: Optional[str], - tmp_path: pathlib.Path): + model_name, + encoding_name, + tmp_path, + ) + check_hf_tokenizer_equivalence( + wrapped_tokenizer, + reloaded_wrapped_tokenizer, + ) + + +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_tiktoken_encode_plus( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): # Testing encode_plus which optionally wrap encodes with bos and eos tokens - wrapped_tokenizer, _, _ = get_tokenizers_for_testing(model_name, - encoding_name, - tmp_path, - add_bos_token=True, - add_eos_token=True) + wrapped_tokenizer, _, _ = get_tokenizers_for_testing( + model_name, + encoding_name, + tmp_path, + add_bos_token=True, + add_eos_token=True, + ) for test_string in TEST_STRINGS: encoded_outputs = wrapped_tokenizer.encode_plus( test_string, add_special_tokens=True, - return_special_tokens_mask=True) + return_special_tokens_mask=True, + ) encoded_input_ids = encoded_outputs.input_ids assert encoded_input_ids[0] == wrapped_tokenizer.bos_token_id assert encoded_input_ids[-1] == wrapped_tokenizer.eos_token_id @@ -340,11 +439,15 @@ def test_tiktoken_encode_plus(model_name: Optional[str], assert encoded_special_mask[-1] == 1 -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_additional_special_tokens(model_name: Optional[str], - encoding_name: Optional[str], - tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_additional_special_tokens( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): special_token_to_add = '<|im_start|>' input_string = special_token_to_add + ' hello' wrapped_tokenizer, _, _ = get_tokenizers_for_testing( @@ -353,39 +456,51 @@ def test_additional_special_tokens(model_name: Optional[str], tmp_path, add_bos_token=False, add_eos_token=False, - additional_special_tokens=[special_token_to_add]) + additional_special_tokens=[special_token_to_add], + ) encoded_outputs = wrapped_tokenizer(input_string)['input_ids'] assert encoded_outputs[0] == wrapped_tokenizer.vocab_size assert len(encoded_outputs) == 2 decoded_outputs = wrapped_tokenizer.decode( - encoded_outputs, spaces_between_special_tokens=False) + encoded_outputs, + spaces_between_special_tokens=False, + ) assert decoded_outputs == input_string def test_additional_special_tokens_len(): special_token_to_add = '<|im_start|>' with_special = TiktokenTokenizerWrapper( - model_name='gpt-4', additional_special_tokens=[special_token_to_add]) + model_name='gpt-4', + additional_special_tokens=[special_token_to_add], + ) no_special = TiktokenTokenizerWrapper(model_name='gpt-4',) assert len(with_special.get_vocab()) == len(no_special.get_vocab()) + 1 - ret = with_special.add_special_tokens( - {'additional_special_tokens': ['<|im_start|>']}) + ret = with_special.add_special_tokens({ + 'additional_special_tokens': ['<|im_start|>'], + }) assert ret == 0 - ret = with_special.add_special_tokens( - {'additional_special_tokens': ['<|im_end|>']}) + ret = with_special.add_special_tokens({ + 'additional_special_tokens': ['<|im_end|>'], + }) assert ret == 1 assert len(with_special.get_vocab()) == len(no_special.get_vocab()) + 2 -@pytest.mark.parametrize('model_name,encoding_name', - MODEL_ENCODING_NAME_PARAMETRIZATION) -def test_chat_formatting(model_name: Optional[str], - encoding_name: Optional[str], tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION, +) +def test_chat_formatting( + model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path, +): special_tokens_to_add = ['<|im_start|>', ''] # Default behavior to not use default system prompt. wrapped_tokenizer, _, _ = get_tokenizers_for_testing( @@ -394,10 +509,14 @@ def test_chat_formatting(model_name: Optional[str], tmp_path, add_bos_token=False, add_eos_token=False, - additional_special_tokens=special_tokens_to_add) + additional_special_tokens=special_tokens_to_add, + ) for i, dict_chats in enumerate(MULTI_TURN_CHAT_ML): chat_str = wrapped_tokenizer.apply_chat_template( - dict_chats, tokenize=False, add_generation_prompt=False) + dict_chats, + tokenize=False, + add_generation_prompt=False, + ) assert chat_str == MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT[i] # Using default system prompt. wrapped_tokenizer, _, _ = get_tokenizers_for_testing( @@ -407,14 +526,21 @@ def test_chat_formatting(model_name: Optional[str], use_default_system_prompt=True, add_bos_token=False, add_eos_token=False, - additional_special_tokens=special_tokens_to_add) + additional_special_tokens=special_tokens_to_add, + ) for i, dict_chats in enumerate(MULTI_TURN_CHAT_ML): chat_str = wrapped_tokenizer.apply_chat_template( - dict_chats, tokenize=False, add_generation_prompt=False) + dict_chats, + tokenize=False, + add_generation_prompt=False, + ) assert chat_str == MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT[i] for i, dict_chats in enumerate(MULTI_TURN_GENERATE_CHAT_ML): chat_str = wrapped_tokenizer.apply_chat_template( - dict_chats, tokenize=False, add_generation_prompt=True) + dict_chats, + tokenize=False, + add_generation_prompt=True, + ) assert chat_str == MULTI_TURN_GENERATE_STRING[i] diff --git a/tests/tokenizers/test_tokenizer.py b/tests/tokenizers/test_tokenizer.py index 5f1e826177..b4f1846091 100644 --- a/tests/tokenizers/test_tokenizer.py +++ b/tests/tokenizers/test_tokenizer.py @@ -13,17 +13,22 @@ def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'): def test_load_tokenizer(): test_cfg = get_config( - conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml') + conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml', + ) truncation = True padding = 'max_length' - resolved_om_tokenizer_config = om.to_container(test_cfg.tokenizer, - resolve=True) + resolved_om_tokenizer_config = om.to_container( + test_cfg.tokenizer, + resolve=True, + ) tokenizer_kwargs = resolved_om_tokenizer_config.get( # type: ignore 'kwargs', {}) tokenizer_name = resolved_om_tokenizer_config['name'] # type: ignore - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, - **tokenizer_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + **tokenizer_kwargs, + ) tokenizer.pad_token = tokenizer.eos_token assert tokenizer.vocab_size == 50254 assert tokenizer.name_or_path == 'EleutherAI/gpt-neox-20b' @@ -48,7 +53,8 @@ def test_load_tokenizer(): in_str, truncation=truncation, padding=padding, - max_length=tokenizer.model_max_length)['input_ids'] + max_length=tokenizer.model_max_length, + )['input_ids'] out_pad_tokens = out_token_key + [0] * (tokenizer.model_max_length - 4) assert padded_tokenize == out_pad_tokens @@ -61,7 +67,8 @@ def test_load_tokenizer(): in_str, truncation=truncation, padding=padding, - max_length=tokenizer.model_max_length)['input_ids'] + max_length=tokenizer.model_max_length, + )['input_ids'] assert padded_tokenize == out_pad_tokens # check attn mask @@ -69,6 +76,7 @@ def test_load_tokenizer(): in_str, truncation=truncation, padding=padding, - max_length=tokenizer.model_max_length)['attention_mask'] + max_length=tokenizer.model_max_length, + )['attention_mask'] attn_mask_key = [1, 1, 1, 1] + [0] * (tokenizer.model_max_length - 4) assert attention_mask == attn_mask_key diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 0a4af4538d..f64925e6dd 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -18,23 +18,31 @@ from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, - build_callback, build_eval_loaders, - build_evaluators, build_logger, - build_optimizer, build_tokenizer) - - -@pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ - ('tiktoken', { - 'model_name': 'gpt-4' - }), - ('EleutherAI/gpt-neo-125M', { - 'model_max_length': 10 - }), - ('mosaicml/mpt-7b', { - 'model_max_length': 20 - }), -]) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_callback, + build_eval_loaders, + build_evaluators, + build_logger, + build_optimizer, + build_tokenizer, +) + + +@pytest.mark.parametrize( + 'tokenizer_name,tokenizer_kwargs', + [ + ('tiktoken', { + 'model_name': 'gpt-4', + }), + ('EleutherAI/gpt-neo-125M', { + 'model_max_length': 10, + }), + ('mosaicml/mpt-7b', { + 'model_max_length': 20, + }), + ], +) def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) @@ -42,15 +50,16 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): assert isinstance(tokenizer, TiktokenTokenizerWrapper) assert tokenizer.model_name == tokenizer_kwargs['model_name'] else: - assert tokenizer.model_max_length == tokenizer_kwargs[ - 'model_max_length'] + assert tokenizer.model_max_length == tokenizer_kwargs['model_max_length' + ] assert isinstance(tokenizer, PreTrainedTokenizerBase) def test_tokenizer_no_EOS(): with pytest.raises( - ValueError, - match='The tokenizer bert-base-uncased must have an eos_token.'): + ValueError, + match='The tokenizer bert-base-uncased must have an eos_token.', + ): build_tokenizer('bert-base-uncased', {}) @@ -68,8 +77,11 @@ def test_build_generate_callback( interval_value: Union[str, int], ): - with mock.patch.object(Generate, '__init__', - autospec=True) as mock_generate: + with mock.patch.object( + Generate, + '__init__', + autospec=True, + ) as mock_generate: mock_generate.return_value = None build_callback( name='generate_callback', @@ -91,8 +103,11 @@ def test_build_generate_callback( def test_build_generate_callback_unspecified_interval(): with pytest.raises(TypeError): - with mock.patch.object(Generate, '__init__', - autospec=True) as mock_generate: + with mock.patch.object( + Generate, + '__init__', + autospec=True, + ) as mock_generate: mock_generate.return_value = None build_callback( name='generate_callback', @@ -105,8 +120,10 @@ def test_build_generate_callback_unspecified_interval(): def test_build_hf_checkpointer_callback(): - with mock.patch.object(HuggingFaceCheckpointer, - '__init__') as mock_hf_checkpointer: + with mock.patch.object( + HuggingFaceCheckpointer, + '__init__', + ) as mock_hf_checkpointer: mock_hf_checkpointer.return_value = None save_folder = 'path_to_save_folder' save_interval = 1 @@ -115,15 +132,15 @@ def test_build_hf_checkpointer_callback(): 'databricks_model_family': 'MptForCausalLM', 'databricks_model_size_parameters': '7b', 'databricks_model_source': 'mosaic-fine-tuning', - 'task': 'llm/v1/completions' - } + 'task': 'llm/v1/completions', + }, } build_callback( name='hf_checkpointer', kwargs={ 'save_folder': save_folder, 'save_interval': save_interval, - 'mlflow_logging_config': mlflow_logging_config_dict + 'mlflow_logging_config': mlflow_logging_config_dict, }, ) @@ -145,8 +162,8 @@ def test_build_logger(): 'init_kwargs': { 'config': { 'foo': 'bar', - } - } + }, + }, } wandb_logger = build_logger('wandb', logger_cfg) # type: ignore assert isinstance(wandb_logger, WandBLogger) @@ -173,44 +190,53 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore return self.linear1(self.norm0(self.linear0(x))) -@pytest.mark.parametrize('name, optimizer_config', [ - ('decoupled_adamw', {}), - ('decoupled_lionw', {}), - ('clip_lion', {}), - ('adalr_lion', {}), -]) -@pytest.mark.parametrize('opt_additional_config', [ - { - 'disable_grad': 'norm' - }, - { - 'disable_grad': ['norm', 'bias'] - }, - { - 'param_groups': [{ - 'param_str_match': 'norm', - 'lr': 1e-9, - 'weight_decay': 0.0, - },] - }, - { - 'param_groups': [{ - 'param_str_match': 'no.*.bias', - 'lr': 1e-9, - 'weight_decay': 0.0, - },] - }, - { - 'param_groups': [{ - 'param_str_match': 'norm', - 'lr': 1e-4, - 'weight_decay': 0.0, - },], - 'disable_grad': ['bias'], - }, -]) -def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], - opt_additional_config: Dict[str, Any]): +@pytest.mark.parametrize( + 'name, optimizer_config', + [ + ('decoupled_adamw', {}), + ('decoupled_lionw', {}), + ('clip_lion', {}), + ('adalr_lion', {}), + ], +) +@pytest.mark.parametrize( + 'opt_additional_config', + [ + { + 'disable_grad': 'norm', + }, + { + 'disable_grad': ['norm', 'bias'], + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-9, + 'weight_decay': 0.0, + },], + }, + { + 'param_groups': [{ + 'param_str_match': 'no.*.bias', + 'lr': 1e-9, + 'weight_decay': 0.0, + },], + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-4, + 'weight_decay': 0.0, + },], + 'disable_grad': ['bias'], + }, + ], +) +def test_build_optimizer( + name: str, + optimizer_config: Dict[str, Any], + opt_additional_config: Dict[str, Any], +): model = _DummyModule() optimizer_config = deepcopy(optimizer_config) optimizer_config.update(deepcopy(opt_additional_config)) @@ -227,8 +253,9 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], if 'param_groups' in opt_additional_config.keys(): for param_group_config, param_group in zip( - opt_additional_config['param_groups'], - optimizer.param_groups[1:]): + opt_additional_config['param_groups'], + optimizer.param_groups[1:], + ): param_group_config = deepcopy(param_group_config) param_str_match = param_group_config.pop('param_str_match') @@ -249,7 +276,8 @@ def test_build_evaluators_empty(): tokenizer=None, # type: ignore device_eval_batch_size=1, icl_seq_len=2, - icl_subset_num_batches=3) + icl_subset_num_batches=3, + ) assert evaluators == [] assert logger_keys == [] assert eval_gauntlet_callback is None @@ -266,8 +294,11 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): 'drop_last': False, 'num_workers': 8, }) - monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', - lambda *args, **kwargs: MagicMock()) + monkeypatch.setattr( + 'llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, + **kwargs: MagicMock(), + ) eval_loaders = build_eval_loaders(eval_loader_cfg, tokenizer, 2) assert len(eval_loaders) == 1 @@ -294,10 +325,13 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): }, 'drop_last': False, 'num_workers': 8, - } + }, ]) - monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', - lambda *args, **kwargs: MagicMock()) + monkeypatch.setattr( + 'llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, + **kwargs: MagicMock(), + ) eval_loaders2 = build_eval_loaders(multi_eval_loader_cfg, tokenizer, 2) assert len(eval_loaders2) == 2 @@ -330,7 +364,7 @@ def test_add_metrics_to_eval_loaders(): metric_names=['c'], dataloader=None, # type: ignore device_eval_microbatch_size=1, - ) + ), ] new_evaluators = add_metrics_to_eval_loaders(evaluators, ['new1', 'new2']) diff --git a/tests/utils/test_huggingface_hub_utils.py b/tests/utils/test_huggingface_hub_utils.py index 39dbf2781d..ffb20a909a 100644 --- a/tests/utils/test_huggingface_hub_utils.py +++ b/tests/utils/test_huggingface_hub_utils.py @@ -5,8 +5,10 @@ import pytest -from llmfoundry.utils.huggingface_hub_utils import (_flatten_import, - _remove_import) +from llmfoundry.utils.huggingface_hub_utils import ( + _flatten_import, + _remove_import, +) def test_flatten_import_true(): @@ -19,8 +21,10 @@ def test_flatten_import_false(): assert not _flatten_import(node, ('x', 'z')) -@pytest.mark.parametrize('prefix_to_remove,expected_imports_remaining', - [('llmfoundry', 1), ('llmfoundry.utils', 2)]) +@pytest.mark.parametrize( + 'prefix_to_remove,expected_imports_remaining', + [('llmfoundry', 1), ('llmfoundry.utils', 2)], +) def test_remove_imports(prefix_to_remove: str, expected_imports_remaining: int): source_code = """ from llmfoundry import a @@ -33,8 +37,10 @@ def test_remove_imports(prefix_to_remove: str, expected_imports_remaining: int): imports_kept = 0 for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and not _remove_import( - node, [prefix_to_remove]): + if isinstance( + node, + ast.ImportFrom, + ) and not _remove_import(node, [prefix_to_remove]): imports_kept += 1 assert imports_kept == expected_imports_remaining diff --git a/tests/utils/test_mlflow_logging.py b/tests/utils/test_mlflow_logging.py index b8dd0becdf..205c985e97 100644 --- a/tests/utils/test_mlflow_logging.py +++ b/tests/utils/test_mlflow_logging.py @@ -7,8 +7,10 @@ import pytest from omegaconf import OmegaConf -from llmfoundry.utils.config_utils import (_log_dataset_uri, - _parse_source_dataset) +from llmfoundry.utils.config_utils import ( + _log_dataset_uri, + _parse_source_dataset, +) mlflow = pytest.importorskip('mlflow') from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource @@ -20,16 +22,20 @@ def create_config(**kwargs: Any): def test_parse_source_dataset_delta_table(): - cfg = create_config(source_dataset_train='db.schema.train_table', - source_dataset_eval='db.schema.eval_table') + cfg = create_config( + source_dataset_train='db.schema.train_table', + source_dataset_eval='db.schema.eval_table', + ) expected = [('delta_table', 'db.schema.train_table', 'train'), ('delta_table', 'db.schema.eval_table', 'eval')] assert _parse_source_dataset(cfg) == expected def test_parse_source_dataset_uc_volume(): - cfg = create_config(source_dataset_train='dbfs:/Volumes/train_data', - source_dataset_eval='dbfs:/Volumes/eval_data') + cfg = create_config( + source_dataset_train='dbfs:/Volumes/train_data', + source_dataset_eval='dbfs:/Volumes/eval_data', + ) expected = [('uc_volume', '/Volumes/train_data', 'train'), ('uc_volume', '/Volumes/eval_data', 'eval')] assert _parse_source_dataset(cfg) == expected @@ -42,25 +48,28 @@ def test_parse_source_dataset_hf(): }}, eval_loader={'dataset': { 'hf_name': 'huggingface/eval_dataset', - }}) + }}, + ) expected = [('hf', 'huggingface/train_dataset', 'train'), ('hf', 'huggingface/eval_dataset', 'eval')] assert _parse_source_dataset(cfg) == expected def test_parse_source_dataset_remote(): - cfg = create_config(train_loader={ - 'dataset': { - 'remote': 'https://remote/train_dataset', - 'split': 'train' - } - }, - eval_loader={ - 'dataset': { - 'remote': 'https://remote/eval_dataset', - 'split': 'eval' - } - }) + cfg = create_config( + train_loader={ + 'dataset': { + 'remote': 'https://remote/train_dataset', + 'split': 'train', + }, + }, + eval_loader={ + 'dataset': { + 'remote': 'https://remote/eval_dataset', + 'split': 'eval', + }, + }, + ) expected = [('https', 'https://remote/train_dataset/train/', 'train'), ('https', 'https://remote/eval_dataset/eval/', 'eval')] assert _parse_source_dataset(cfg) == expected @@ -69,13 +78,14 @@ def test_parse_source_dataset_remote(): def test_log_dataset_uri(): cfg = create_config( train_loader={'dataset': { - 'hf_name': 'huggingface/train_dataset' + 'hf_name': 'huggingface/train_dataset', }}, eval_loader={'dataset': { - 'hf_name': 'huggingface/eval_dataset' + 'hf_name': 'huggingface/eval_dataset', }}, source_dataset_train='huggingface/train_dataset', - source_dataset_eval='huggingface/eval_dataset') + source_dataset_eval='huggingface/eval_dataset', + ) with patch('mlflow.log_input') as mock_log_input: _log_dataset_uri(cfg) @@ -85,12 +95,15 @@ def test_log_dataset_uri(): ] assert all( isinstance(call.source, HuggingFaceDatasetSource) - for call in meta_dataset_calls), 'Source types are incorrect' + for call in meta_dataset_calls + ), 'Source types are incorrect' # Verify the names assert meta_dataset_calls[ - 0].name == 'train', f"Expected 'train', got {meta_dataset_calls[0].name}" + 0 + ].name == 'train', f"Expected 'train', got {meta_dataset_calls[0].name}" assert meta_dataset_calls[ - 1].name == 'eval', f"Expected 'eval', got {meta_dataset_calls[1].name}" + 1 + ].name == 'eval', f"Expected 'eval', got {meta_dataset_calls[1].name}" def test_multiple_eval_datasets(): @@ -109,7 +122,7 @@ def test_multiple_eval_datasets(): 'dataset': { 'hf_name': 'huggingface/eval_dataset2', }, - }] + }], }) expected_data_paths = [('hf', 'huggingface/train_dataset', 'train'), @@ -121,7 +134,8 @@ def test_multiple_eval_datasets(): mock_meta_dataset.side_effect = lambda source, name: MagicMock() data_paths = _parse_source_dataset(cfg) assert sorted(data_paths) == sorted( - expected_data_paths), 'Data paths did not match expected' + expected_data_paths, + ), 'Data paths did not match expected' @pytest.fixture diff --git a/tests/utils/test_model_download_utils.py b/tests/utils/test_model_download_utils.py index 08e11a3b0e..8519277e74 100644 --- a/tests/utils/test_model_download_utils.py +++ b/tests/utils/test_model_download_utils.py @@ -17,8 +17,12 @@ from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME from llmfoundry.utils.model_download_utils import ( - DEFAULT_IGNORE_PATTERNS, PYTORCH_WEIGHTS_PATTERN, SAFE_WEIGHTS_PATTERN, - download_from_hf_hub, download_from_http_fileserver) + DEFAULT_IGNORE_PATTERNS, + PYTORCH_WEIGHTS_PATTERN, + SAFE_WEIGHTS_PATTERN, + download_from_hf_hub, + download_from_http_fileserver, +) # ======================== download_from_hf_hub tests ======================== @@ -95,25 +99,30 @@ ]) @mock.patch('huggingface_hub.snapshot_download') @mock.patch('huggingface_hub.list_repo_files') -def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, - mock_snapshot_download: MagicMock, - prefer_safetensors: bool, - repo_files: List[str], - expected_ignore_patterns: List[str]): +def test_download_from_hf_hub_weights_pref( + mock_list_repo_files: MagicMock, + mock_snapshot_download: MagicMock, + prefer_safetensors: bool, + repo_files: List[str], + expected_ignore_patterns: List[str], +): test_repo_id = 'test_repo_id' save_dir = 'save_dir' mock_list_repo_files.return_value = repo_files - download_from_hf_hub(test_repo_id, - save_dir=save_dir, - prefer_safetensors=prefer_safetensors) + download_from_hf_hub( + test_repo_id, + save_dir=save_dir, + prefer_safetensors=prefer_safetensors, + ) mock_snapshot_download.assert_called_once_with( test_repo_id, local_dir=save_dir, local_dir_use_symlinks=False, allow_patterns=None, ignore_patterns=expected_ignore_patterns, - token=None) + token=None, + ) @mock.patch('huggingface_hub.snapshot_download') @@ -185,9 +194,11 @@ def test_download_from_hf_hub_retry( @mock.patch.object(requests.Session, 'get') @mock.patch('os.makedirs') @mock.patch('builtins.open') -def test_download_from_http_fileserver(mock_open: MagicMock, - mock_makedirs: MagicMock, - mock_get: MagicMock): +def test_download_from_http_fileserver( + mock_open: MagicMock, + mock_makedirs: MagicMock, + mock_get: MagicMock, +): model_url = f'https://cache.com/models/model/' save_dir = 'save_dir/' diff --git a/tests/utils/test_prompt_files.py b/tests/utils/test_prompt_files.py index 12a5d02999..d77360dd75 100644 --- a/tests/utils/test_prompt_files.py +++ b/tests/utils/test_prompt_files.py @@ -13,6 +13,10 @@ def test_load_prompt_strings(tmp_path: Path): f.write('hello goodbye') temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'prompts.txt') - assert utils.load_prompts( - [temp, temp, 'why'], - ' ') == ['hello', 'goodbye', 'hello', 'goodbye', 'why'] + assert utils.load_prompts([temp, temp, 'why'], ' ') == [ + 'hello', + 'goodbye', + 'hello', + 'goodbye', + 'why', + ]