diff --git a/llmfoundry/cli/cli.py b/llmfoundry/cli/cli.py index 25c1a6d230..8e86e76467 100644 --- a/llmfoundry/cli/cli.py +++ b/llmfoundry/cli/cli.py @@ -1,12 +1,29 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + import typer from llmfoundry.cli import registry_cli +from llmfoundry.train import train_from_yaml app = typer.Typer(pretty_exceptions_show_locals=False) app.add_typer(registry_cli.app, name='registry') + +@app.command(name='train') +def train( + yaml_path: str = typer.Argument( + ..., + help='Path to the YAML configuration file', + ), # type: ignore + args_list: Optional[list[str]] = typer. + Argument(None, help='Additional command line arguments'), # type: ignore +): + """Run the training with optional overrides from CLI.""" + train_from_yaml(yaml_path, args_list) + + if __name__ == '__main__': app() diff --git a/llmfoundry/train/__init__.py b/llmfoundry/train/__init__.py new file mode 100644 index 0000000000..8a4c2749db --- /dev/null +++ b/llmfoundry/train/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +from llmfoundry.train.train import ( + TRAIN_CONFIG_KEYS, + TrainConfig, + train, + train_from_yaml, + validate_config, +) + +__all__ = [ + 'train', + 'train_from_yaml', + 'TrainConfig', + 'TRAIN_CONFIG_KEYS', + 'validate_config', +] diff --git a/llmfoundry/train/train.py b/llmfoundry/train/train.py new file mode 100644 index 0000000000..273372e1cd --- /dev/null +++ b/llmfoundry/train/train.py @@ -0,0 +1,570 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import gc +import logging +import os +import time +import warnings +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed +from composer import ComposerModel, Trainer +from composer.core.callback import Callback +from composer.profiler import ( + JSONTraceHandler, + Profiler, + TraceHandler, + cyclic_schedule, +) +from composer.utils import dist, get_device, reproducibility +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from llmfoundry.callbacks import AsyncEval, HuggingFaceCheckpointer +from llmfoundry.data.dataloader import build_dataloader +from llmfoundry.eval.metrics.nlp import InContextLearningMetric +from llmfoundry.layers_registry import ffns_with_megablocks +from llmfoundry.utils import ( + find_mosaicml_logger, + log_train_analytics, + maybe_create_mosaicml_logger, +) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_algorithm, + build_callback, + build_composer_model, + build_evaluators, + build_logger, + build_optimizer, + build_scheduler, + build_tokenizer, +) +from llmfoundry.utils.config_utils import ( + TRAIN_CONFIG_KEYS, + TrainConfig, + log_config, + log_dataset_uri, + make_dataclass_and_log_config, + pop_config, + process_init_device, +) +from llmfoundry.utils.exceptions import ( + BaseContextualError, + EvalDataLoaderLocation, + TrainDataLoaderLocation, +) +from llmfoundry.utils.registry_utils import import_file + +log = logging.getLogger(__name__) + + +def validate_config(train_config: TrainConfig): + """Validates compatible model and dataloader selection.""" + # Validate the rest of the config + loaders = [train_config.train_loader] + if train_config.eval_loaders is not None: + for loader in (train_config.eval_loaders or []): # pyright + if 'label' not in loader or loader['label'] is None: + raise ValueError( + 'When specifying multiple evaluation datasets, each one must include the \ + `label` attribute.', + ) + loaders.append(loader) + if train_config.eval_loader is not None: + loaders.append(train_config.eval_loader) + for loader in loaders: + if loader['name'] == 'text': + if train_config.model['name'] == 'hf_t5': + raise ValueError( + f'Model type "{train_config.model["name"]}" is not supported when using the "text " ' +\ + f'dataloader. Only finetuning is supported.') + + if train_config.icl_tasks is not None or train_config.icl_tasks_str is not None: + if train_config.model['name'] == 'hf_t5': + raise ValueError( + 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".', + ) + + if ( + train_config.model.get('fc_type', 'torch') != 'te' and + 'te' not in train_config.model.get('ffn_config', + {}).get('ffn_type', 'mptmlp') and + 'fp8' in train_config.precision + ): + warnings.warn( + "fp8 only supported for te.Linear layers. Either set `cfg.model.fc_typ='te'` or " + + + "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision.", + ) + + if ( + train_config.model.get('fc_type', 'torch') == 'te' or 'te' + in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') + ): + fsdp_config = train_config.fsdp_config + act_ckpt = fsdp_config.get( + 'activation_checkpointing', + False, + ) if fsdp_config else False + act_ckpt_reentrant = fsdp_config.get( + 'activation_checkpointing_reentrant', + False, + ) if fsdp_config else False + if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: + warnings.warn( + '`te.Linear` layers do not support activation_checkpointing with ' + + '`activation_checkpointing_reentrant = True`. ' + + 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.', + ) + assert train_config.fsdp_config is not None # pyright (this is known because fsdp_config is not None) + train_config.fsdp_config['activation_checkpointing_reentrant' + ] = False + + if train_config.model.get('ffn_config', + {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': + warnings.warn( + '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + + 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.', + ) + torch._dynamo.config.suppress_errors = True # type: ignore (third-party) + + if train_config.model.get('load_in_8bit', False): + raise ValueError( + '`load_in_8bit` is only supported for evaluation rather than training.', + ) + + if train_config.model.get('ffn_config', {}).get( + 'ffn_type', + 'mptmlp', + ) in ffns_with_megablocks: + moe_world_size = train_config.model.get('ffn_config', + {}).get('moe_world_size', 1) + use_orig_params = train_config.fsdp_config.get( + 'use_orig_params', + True, + ) if train_config.fsdp_config is not None else True + if moe_world_size > 1 and not use_orig_params: + raise ValueError( + f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.', + ) + + +def _log_num_params(model: ComposerModel, logged_cfg: Dict[str, Any]): + # Log number of parameters + if hasattr(model, 'n_total_params'): + n_params = model.n_total_params + n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. + else: + n_params = sum(p.numel() for p in model.parameters()) + n_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + if hasattr(model, 'n_active_params'): + n_active_params = model.n_active_params + else: + n_active_params = n_params + logged_cfg.update({ + 'n_params': n_params, + 'n_active_params': n_active_params, + 'n_trainable_params': n_trainable_params, + }) + + +def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): + """Initialize distributed and test setup with a barrier. + + Args: + dist_timeout (Union[int, float]): Timeout for initializing the process group + """ + log.debug('Initializing dist with device...') + dist.initialize_dist(get_device(None), timeout=dist_timeout) + log.debug('Testing barrier with device...') + dist.barrier() + log.debug('Barrier test passed with device.') + + +def train(cfg: DictConfig) -> Trainer: + code_paths = cfg.get('code_paths', []) + # Import any user provided code + for code_path in code_paths: + import_file(code_path) + + logged_cfg, train_cfg = make_dataclass_and_log_config( + cfg, + TrainConfig, + TRAIN_CONFIG_KEYS, + transforms='all', + ) + + # Set logging level + if train_cfg.python_log_level is not None: + logging.basicConfig( + # Example of format string + # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', + ) + logging.getLogger('llmfoundry').setLevel( + train_cfg.python_log_level.upper(), + ) # Foundry module + logging.getLogger(__name__).setLevel( + train_cfg.python_log_level.upper(), + ) # Train script + + _initialize_dist_with_barrier(dist_timeout=train_cfg.dist_timeout) + + # Filter deprecation warning from torch internal usage + warnings.filterwarnings( + action='ignore', + category=UserWarning, + message= + 'torch.distributed.*_base is a private function and will be deprecated.*', + ) + + # Check for incompatibilities between the model and data loaders + validate_config(train_cfg) + + cuda_alloc_conf = [] + # Get max split size mb + max_split_size_mb: Optional[int] = train_cfg.max_split_size_mb + if max_split_size_mb is not None: + cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}') + + # Expandable segments + if train_cfg.expandable_segments: + cuda_alloc_conf.append('expandable_segments:True') + + if len(cuda_alloc_conf) > 0: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf) + + # Set CUDA lazy loading + # This can save a bit of memory if not all modules are needed + cuda_load_lazy: bool = train_cfg.cuda_load_lazy + if cuda_load_lazy: + os.environ['CUDA_MODULE_LOADING'] = 'LAZY' + + # Set seed first + seed: int = train_cfg.seed + reproducibility.seed_all(seed) + + # Mandatory model training configs + model_config = train_cfg.model + train_loader_config = train_cfg.train_loader + + # Optional fsdp data, fine-tuning, and eval configs + fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config + + eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders + icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str + eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str + + # Optional parameters will be set to default values if not specified. + default_run_name: str = os.environ.get('RUN_NAME', 'llm') + run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name + is_state_dict_sharded: bool = ( + fsdp_config.get('state_dict_type', 'full') == 'sharded' + ) if fsdp_config else False + save_latest_filename: str = train_cfg.save_latest_filename if train_cfg.save_latest_filename else 'latest-sharded-rank{rank}' if is_state_dict_sharded else 'latest-rank{rank}.pt' + save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' + + # Enable autoresume from model checkpoints if possible + autoresume_default: bool = False + if logged_cfg.get('run_name', None) is not None \ + and train_cfg.save_folder is not None \ + and not train_cfg.save_overwrite \ + and not train_cfg.save_weights_only: + autoresume_default = True + + if not train_cfg.autoresume and autoresume_default: + log.info( + 'As run_name, save_folder, and save_latest_filename are set, \ + changing autoresume default to True...', + ) + + # Warn if fsdp is enabled but user only has 1 GPU + if dist.get_world_size() == 1 and fsdp_config is not None: + warnings.warn( + 'FSDP is not applicable for single-GPU training. Reverting to DDP.', + ) + fsdp_config = None + + # Initialize context + init_context = process_init_device(model_config, fsdp_config) + logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) + + # Build tokenizer + log.info('Building tokenizer...') + tokenizer_name = train_cfg.tokenizer['name'] + tokenizer_kwargs = train_cfg.tokenizer.get('kwargs', {}) + tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) + + # Scheduler + scheduler_name: str = train_cfg.scheduler.pop('name') + scheduler = build_scheduler(scheduler_name, train_cfg.scheduler) + + # Loggers + loggers = [ + build_logger(str(name), logger_cfg) + for name, logger_cfg in train_cfg.loggers.items() + ] if train_cfg.loggers else [] + + mosaicml_logger = find_mosaicml_logger(loggers) + if mosaicml_logger is None: + mosaicml_logger = maybe_create_mosaicml_logger() + if mosaicml_logger is not None: + # mosaicml_logger will be None if run isn't on MosaicML platform + loggers.append(mosaicml_logger) + + if train_cfg.metadata is not None: + # Optionally flatten the metadata for logging + if train_cfg.flatten_metadata: + logged_cfg.pop('metadata', None) + common_keys = set( + logged_cfg.keys(), + ) & set(train_cfg.metadata.keys()) + if len(common_keys) > 0: + raise ValueError( + f'Keys {common_keys} are already present in the config. Please rename them in metadata ' + + + 'or set flatten_metadata=False to avoid flattening the metadata in the logged config.', + ) + + logged_cfg.update(train_cfg.metadata, merge=True) + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(train_cfg.metadata) + mosaicml_logger._flush_metadata(force_flush=True) + + # Profiling + profiler: Optional[Profiler] = None + profiler_cfg = train_cfg.profiler + if profiler_cfg: + profiler_schedule_cfg: Dict = pop_config( + profiler_cfg, + 'schedule', + must_exist=True, + ) + profiler_schedule = cyclic_schedule(**profiler_schedule_cfg) + # Only support json trace handler + profiler_trace_handlers: List[TraceHandler] = [] + profiler_trace_cfg: Optional[Dict] = pop_config( + profiler_cfg, + 'json_trace_handler', + must_exist=False, + default_value=None, + ) + if profiler_trace_cfg: + profiler_trace_handlers.append( + JSONTraceHandler(**profiler_trace_cfg), + ) + profiler = Profiler( + **profiler_cfg, + trace_handlers=profiler_trace_handlers, + schedule=profiler_schedule, + ) + + callback_configs = train_cfg.callbacks or {} + + # Callbacks + callbacks: List[Callback] = [ + build_callback( + name=str(name), + kwargs=callback_cfg, + train_config=logged_cfg, + ) for name, callback_cfg in callback_configs.items() + ] + + use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) + + algorithm_configs = train_cfg.algorithms or {} + + # Algorithms + algorithms = [ + build_algorithm(str(name), algorithm_cfg) + for name, algorithm_cfg in algorithm_configs.items() + ] + + # Dataloaders + log.info('Building train loader...') + try: + train_loader = build_dataloader( + train_loader_config, + tokenizer, + train_cfg.device_train_batch_size, + ) + except BaseContextualError as e: + e.location = TrainDataLoaderLocation + raise e + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics({'data_validated': time.time()}) + + ## Evaluation + if use_async_eval: + evaluators = [] + if train_cfg.eval_first: + warnings.warn( + 'AsyncEval callback does not support eval_first=True. Ignoring.', + ) + train_cfg.eval_first = False + + else: + try: + log.info('Building eval loader...') + eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks_config, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=train_cfg.device_eval_batch_size, + icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=train_cfg.icl_subset_num_batches, + ) + if eval_gauntlet_callback is not None: + callbacks.append(eval_gauntlet_callback) + except BaseContextualError as e: + e.location = EvalDataLoaderLocation + raise e + + if mosaicml_logger is not None: + log_train_analytics( + mosaicml_logger, + model_config, + train_loader_config, + eval_loader_config, + train_cfg.callbacks, + tokenizer_name, + train_cfg.load_path, + icl_tasks_config, + eval_gauntlet_config, + ) + # Build Model + log.info('Initializing model...') + name = model_config.pop('name') + assert isinstance(name, str) + assert isinstance(model_config, dict) + model = build_composer_model( + name=name, + tokenizer=tokenizer, + init_context=init_context, + master_weights_dtype=model_config.get('master_weights_dtype', None), + cfg=model_config, + ) + + _log_num_params(model, logged_cfg) + + # Optimizer + optimizer_name: str = train_cfg.optimizer.pop('name') + optimizer_cfg = train_cfg.optimizer + optimizer = build_optimizer(model, optimizer_name, optimizer_cfg) + + # Now add the eval metrics + try: + if eval_loader_config is not None and not use_async_eval: + eval_metrics = model.get_metrics(is_train=False) + non_icl_metrics = [ + metric_name for metric_name, metric in eval_metrics.items() + if not isinstance(metric, InContextLearningMetric) + ] + evaluators = add_metrics_to_eval_loaders( + evaluators, + non_icl_metrics, + ) + except BaseContextualError as e: + e.location = EvalDataLoaderLocation + raise e + + compile_config = train_cfg.compile_config + # Build the Trainer + log.info('Building trainer...') + trainer = Trainer( + run_name=run_name, + seed=seed, + model=model, + train_dataloader=train_loader, + eval_dataloader=evaluators, + optimizers=optimizer, + schedulers=scheduler, + max_duration=train_cfg.max_duration, + eval_interval=train_cfg.eval_interval, + eval_subset_num_batches=train_cfg.eval_subset_num_batches, + progress_bar=train_cfg.progress_bar, + log_to_console=train_cfg.log_to_console, + console_log_interval=train_cfg.console_log_interval, + loggers=loggers, + callbacks=callbacks, + precision=train_cfg.precision, + algorithms=algorithms, + device_train_microbatch_size=train_cfg.device_train_microbatch_size, + parallelism_config={'fsdp': fsdp_config}, + save_folder=train_cfg.save_folder, + save_filename=save_filename, + save_latest_filename=save_latest_filename, + save_interval=train_cfg.save_interval, + save_num_checkpoints_to_keep=train_cfg.save_num_checkpoints_to_keep, + save_overwrite=train_cfg.save_overwrite, + save_weights_only=train_cfg.save_weights_only, + load_path=train_cfg.load_path, + load_weights_only=train_cfg.load_weights_only, + load_strict_model_weights=train_cfg.load_strict_model_weights, + load_ignore_keys=train_cfg.load_ignore_keys, + save_ignore_keys=train_cfg.save_ignore_keys, + autoresume=train_cfg.autoresume, + python_log_level=train_cfg.python_log_level, + dist_timeout=train_cfg.dist_timeout, + profiler=profiler, + compile_config=compile_config, + ) + + # Optionally just save an HF checkpoint + if train_cfg.only_hf_checkpoint: + hf_checkpointer_callbacks = [ + c for c in callbacks if isinstance(c, HuggingFaceCheckpointer) + ] + if len(hf_checkpointer_callbacks) == 0: + raise ValueError( + 'No HuggingFaceCheckpointer callback found, but only_hf_checkpoint was set to True. Please add a HuggingFaceCheckpointer.', + ) + if len(hf_checkpointer_callbacks) > 1: + raise ValueError( + 'Multiple HuggingFaceCheckpointer callbacks found, but only_hf_checkpoint was set to True. Please remove all but one HuggingFaceCheckpointer.', + ) + + hf_checkpointer_callback = hf_checkpointer_callbacks[0] + hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger) + return trainer + + if train_cfg.log_config: + log.info('Logging config') + log_config(logged_cfg) + log_dataset_uri(logged_cfg) + torch.cuda.empty_cache() + gc.collect() + + # Eval first if requested + if train_cfg.eval_first and trainer.state.timestamp.batch.value == 0: + trainer.eval() + + log.info('Starting training...') + trainer.fit() + + log.info('Done.') + return trainer + + +def train_from_yaml( + yaml_path: str, + args_list: Optional[List[str]] = None, +) -> Trainer: + """Run the training with optional overrides from CLI.""" + # Load yaml and CLI arguments. + with open(yaml_path) as f: + yaml_cfg = om.load(f) + if args_list: + cli_cfg = om.from_cli(args_list) + yaml_cfg = om.merge(yaml_cfg, cli_cfg) + assert isinstance(yaml_cfg, DictConfig) + return train(yaml_cfg) diff --git a/scripts/train/train.py b/scripts/train/train.py index 668e20dbab..3c8973048b 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -1,571 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import gc -import logging -import os import sys -import time -import warnings -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.distributed -from composer import ComposerModel, Trainer -from composer.core.callback import Callback -from composer.profiler import ( - JSONTraceHandler, - Profiler, - TraceHandler, - cyclic_schedule, -) -from composer.utils import dist, get_device, reproducibility -from omegaconf import DictConfig -from omegaconf import OmegaConf as om - -from llmfoundry.callbacks import AsyncEval, HuggingFaceCheckpointer -from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.eval.metrics.nlp import InContextLearningMetric -from llmfoundry.layers_registry import ffns_with_megablocks -from llmfoundry.utils import ( - find_mosaicml_logger, - log_train_analytics, - maybe_create_mosaicml_logger, -) -from llmfoundry.utils.builders import ( - add_metrics_to_eval_loaders, - build_algorithm, - build_callback, - build_composer_model, - build_evaluators, - build_logger, - build_optimizer, - build_scheduler, - build_tokenizer, -) -from llmfoundry.utils.config_utils import ( - TRAIN_CONFIG_KEYS, - TrainConfig, - log_config, - log_dataset_uri, - make_dataclass_and_log_config, - pop_config, - process_init_device, -) -from llmfoundry.utils.exceptions import ( - BaseContextualError, - EvalDataLoaderLocation, - TrainDataLoaderLocation, -) -from llmfoundry.utils.registry_utils import import_file - -log = logging.getLogger(__name__) - - -def validate_config(train_config: TrainConfig): - """Validates compatible model and dataloader selection.""" - # Validate the rest of the config - loaders = [train_config.train_loader] - if train_config.eval_loaders is not None: - for loader in (train_config.eval_loaders or []): # pyright - if 'label' not in loader or loader['label'] is None: - raise ValueError( - 'When specifying multiple evaluation datasets, each one must include the \ - `label` attribute.', - ) - loaders.append(loader) - if train_config.eval_loader is not None: - loaders.append(train_config.eval_loader) - for loader in loaders: - if loader['name'] == 'text': - if train_config.model['name'] == 'hf_t5': - raise ValueError( - f'Model type "{train_config.model["name"]}" is not supported when using the "text " ' +\ - f'dataloader. Only finetuning is supported.') - - if train_config.icl_tasks is not None or train_config.icl_tasks_str is not None: - if train_config.model['name'] == 'hf_t5': - raise ValueError( - 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".', - ) - - if ( - train_config.model.get('fc_type', 'torch') != 'te' and - 'te' not in train_config.model.get('ffn_config', - {}).get('ffn_type', 'mptmlp') and - 'fp8' in train_config.precision - ): - warnings.warn( - "fp8 only supported for te.Linear layers. Either set `cfg.model.fc_typ='te'` or " - + - "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision.", - ) - - if ( - train_config.model.get('fc_type', 'torch') == 'te' or 'te' - in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') - ): - fsdp_config = train_config.fsdp_config - act_ckpt = fsdp_config.get( - 'activation_checkpointing', - False, - ) if fsdp_config else False - act_ckpt_reentrant = fsdp_config.get( - 'activation_checkpointing_reentrant', - False, - ) if fsdp_config else False - if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: - warnings.warn( - '`te.Linear` layers do not support activation_checkpointing with ' - + '`activation_checkpointing_reentrant = True`. ' + - 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.', - ) - assert train_config.fsdp_config is not None # pyright (this is known because fsdp_config is not None) - train_config.fsdp_config['activation_checkpointing_reentrant' - ] = False - - if train_config.model.get('ffn_config', - {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': - warnings.warn( - '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + - 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.', - ) - torch._dynamo.config.suppress_errors = True # type: ignore (third-party) - - if train_config.model.get('load_in_8bit', False): - raise ValueError( - '`load_in_8bit` is only supported for evaluation rather than training.', - ) - - if train_config.model.get('ffn_config', {}).get( - 'ffn_type', - 'mptmlp', - ) in ffns_with_megablocks: - moe_world_size = train_config.model.get('ffn_config', - {}).get('moe_world_size', 1) - use_orig_params = train_config.fsdp_config.get( - 'use_orig_params', - True, - ) if train_config.fsdp_config is not None else True - if moe_world_size > 1 and not use_orig_params: - raise ValueError( - f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.', - ) - - -def _log_num_params(model: ComposerModel, logged_cfg: Dict[str, Any]): - # Log number of parameters - if hasattr(model, 'n_total_params'): - n_params = model.n_total_params - n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. - else: - n_params = sum(p.numel() for p in model.parameters()) - n_trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - if hasattr(model, 'n_active_params'): - n_active_params = model.n_active_params - else: - n_active_params = n_params - logged_cfg.update({ - 'n_params': n_params, - 'n_active_params': n_active_params, - 'n_trainable_params': n_trainable_params, - }) - - -def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): - """Initialize distributed and test setup with a barrier. - - Args: - dist_timeout (Union[int, float]): Timeout for initializing the process group - """ - log.debug('Initializing dist with device...') - dist.initialize_dist(get_device(None), timeout=dist_timeout) - log.debug('Testing barrier with device...') - dist.barrier() - log.debug('Barrier test passed with device.') - - -def main(cfg: DictConfig) -> Trainer: - code_paths = cfg.get('code_paths', []) - # Import any user provided code - for code_path in code_paths: - import_file(code_path) - - logged_cfg, train_cfg = make_dataclass_and_log_config( - cfg, - TrainConfig, - TRAIN_CONFIG_KEYS, - transforms='all', - ) - - # Set logging level - if train_cfg.python_log_level is not None: - logging.basicConfig( - # Example of format string - # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here - format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging.getLogger('llmfoundry').setLevel( - train_cfg.python_log_level.upper(), - ) # Foundry module - logging.getLogger(__name__).setLevel( - train_cfg.python_log_level.upper(), - ) # Train script - - _initialize_dist_with_barrier(dist_timeout=train_cfg.dist_timeout) - - # Filter deprecation warning from torch internal usage - warnings.filterwarnings( - action='ignore', - category=UserWarning, - message= - 'torch.distributed.*_base is a private function and will be deprecated.*', - ) - - # Check for incompatibilities between the model and data loaders - validate_config(train_cfg) - - cuda_alloc_conf = [] - # Get max split size mb - max_split_size_mb: Optional[int] = train_cfg.max_split_size_mb - if max_split_size_mb is not None: - cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}') - - # Expandable segments - if train_cfg.expandable_segments: - cuda_alloc_conf.append('expandable_segments:True') - - if len(cuda_alloc_conf) > 0: - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf) - - # Set CUDA lazy loading - # This can save a bit of memory if not all modules are needed - cuda_load_lazy: bool = train_cfg.cuda_load_lazy - if cuda_load_lazy: - os.environ['CUDA_MODULE_LOADING'] = 'LAZY' - - # Set seed first - seed: int = train_cfg.seed - reproducibility.seed_all(seed) - - # Mandatory model training configs - model_config = train_cfg.model - train_loader_config = train_cfg.train_loader - - # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config - - eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders - icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str - eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str - - # Optional parameters will be set to default values if not specified. - default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name - is_state_dict_sharded: bool = ( - fsdp_config.get('state_dict_type', 'full') == 'sharded' - ) if fsdp_config else False - save_latest_filename: str = train_cfg.save_latest_filename if train_cfg.save_latest_filename else 'latest-sharded-rank{rank}' if is_state_dict_sharded else 'latest-rank{rank}.pt' - save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' - - # Enable autoresume from model checkpoints if possible - autoresume_default: bool = False - if logged_cfg.get('run_name', None) is not None \ - and train_cfg.save_folder is not None \ - and not train_cfg.save_overwrite \ - and not train_cfg.save_weights_only: - autoresume_default = True - - if not train_cfg.autoresume and autoresume_default: - log.info( - 'As run_name, save_folder, and save_latest_filename are set, \ - changing autoresume default to True...', - ) - - # Warn if fsdp is enabled but user only has 1 GPU - if dist.get_world_size() == 1 and fsdp_config is not None: - warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.', - ) - fsdp_config = None - - # Initialize context - init_context = process_init_device(model_config, fsdp_config) - logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) - - # Build tokenizer - log.info('Building tokenizer...') - tokenizer_name = train_cfg.tokenizer['name'] - tokenizer_kwargs = train_cfg.tokenizer.get('kwargs', {}) - tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - - # Scheduler - scheduler_name: str = train_cfg.scheduler.pop('name') - scheduler = build_scheduler(scheduler_name, train_cfg.scheduler) - - # Loggers - loggers = [ - build_logger(str(name), logger_cfg) - for name, logger_cfg in train_cfg.loggers.items() - ] if train_cfg.loggers else [] - - mosaicml_logger = find_mosaicml_logger(loggers) - if mosaicml_logger is None: - mosaicml_logger = maybe_create_mosaicml_logger() - if mosaicml_logger is not None: - # mosaicml_logger will be None if run isn't on MosaicML platform - loggers.append(mosaicml_logger) - - if train_cfg.metadata is not None: - # Optionally flatten the metadata for logging - if train_cfg.flatten_metadata: - logged_cfg.pop('metadata', None) - common_keys = set( - logged_cfg.keys(), - ) & set(train_cfg.metadata.keys()) - if len(common_keys) > 0: - raise ValueError( - f'Keys {common_keys} are already present in the config. Please rename them in metadata ' - + - 'or set flatten_metadata=False to avoid flattening the metadata in the logged config.', - ) - - logged_cfg.update(train_cfg.metadata, merge=True) - - if mosaicml_logger is not None: - mosaicml_logger.log_metrics(train_cfg.metadata) - mosaicml_logger._flush_metadata(force_flush=True) - - # Profiling - profiler: Optional[Profiler] = None - profiler_cfg = train_cfg.profiler - if profiler_cfg: - profiler_schedule_cfg: Dict = pop_config( - profiler_cfg, - 'schedule', - must_exist=True, - ) - profiler_schedule = cyclic_schedule(**profiler_schedule_cfg) - # Only support json trace handler - profiler_trace_handlers: List[TraceHandler] = [] - profiler_trace_cfg: Optional[Dict] = pop_config( - profiler_cfg, - 'json_trace_handler', - must_exist=False, - default_value=None, - ) - if profiler_trace_cfg: - profiler_trace_handlers.append( - JSONTraceHandler(**profiler_trace_cfg), - ) - profiler = Profiler( - **profiler_cfg, - trace_handlers=profiler_trace_handlers, - schedule=profiler_schedule, - ) - - callback_configs = train_cfg.callbacks or {} - - # Callbacks - callbacks: List[Callback] = [ - build_callback( - name=str(name), - kwargs=callback_cfg, - train_config=logged_cfg, - ) for name, callback_cfg in callback_configs.items() - ] - - use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) - - algorithm_configs = train_cfg.algorithms or {} - - # Algorithms - algorithms = [ - build_algorithm(str(name), algorithm_cfg) - for name, algorithm_cfg in algorithm_configs.items() - ] - - # Dataloaders - log.info('Building train loader...') - try: - train_loader = build_dataloader( - train_loader_config, - tokenizer, - train_cfg.device_train_batch_size, - ) - except BaseContextualError as e: - e.location = TrainDataLoaderLocation - raise e - - if mosaicml_logger is not None: - mosaicml_logger.log_metrics({'data_validated': time.time()}) - - ## Evaluation - if use_async_eval: - evaluators = [] - if train_cfg.eval_first: - warnings.warn( - 'AsyncEval callback does not support eval_first=True. Ignoring.', - ) - train_cfg.eval_first = False - - else: - try: - log.info('Building eval loader...') - eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len - evaluators, _, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks_config, - eval_gauntlet_config, - tokenizer=tokenizer, - device_eval_batch_size=train_cfg.device_eval_batch_size, - icl_seq_len=eval_icl_seq_len, - icl_subset_num_batches=train_cfg.icl_subset_num_batches, - ) - if eval_gauntlet_callback is not None: - callbacks.append(eval_gauntlet_callback) - except BaseContextualError as e: - e.location = EvalDataLoaderLocation - raise e - - if mosaicml_logger is not None: - log_train_analytics( - mosaicml_logger, - model_config, - train_loader_config, - eval_loader_config, - train_cfg.callbacks, - tokenizer_name, - train_cfg.load_path, - icl_tasks_config, - eval_gauntlet_config, - ) - # Build Model - log.info('Initializing model...') - name = model_config.pop('name') - assert isinstance(name, str) - assert isinstance(model_config, dict) - model = build_composer_model( - name=name, - tokenizer=tokenizer, - init_context=init_context, - master_weights_dtype=model_config.get('master_weights_dtype', None), - cfg=model_config, - ) - - _log_num_params(model, logged_cfg) - - # Optimizer - optimizer_name: str = train_cfg.optimizer.pop('name') - optimizer_cfg = train_cfg.optimizer - optimizer = build_optimizer(model, optimizer_name, optimizer_cfg) - - # Now add the eval metrics - try: - if eval_loader_config is not None and not use_async_eval: - eval_metrics = model.get_metrics(is_train=False) - non_icl_metrics = [ - metric_name for metric_name, metric in eval_metrics.items() - if not isinstance(metric, InContextLearningMetric) - ] - evaluators = add_metrics_to_eval_loaders( - evaluators, - non_icl_metrics, - ) - except BaseContextualError as e: - e.location = EvalDataLoaderLocation - raise e - - compile_config = train_cfg.compile_config - # Build the Trainer - log.info('Building trainer...') - trainer = Trainer( - run_name=run_name, - seed=seed, - model=model, - train_dataloader=train_loader, - eval_dataloader=evaluators, - optimizers=optimizer, - schedulers=scheduler, - max_duration=train_cfg.max_duration, - eval_interval=train_cfg.eval_interval, - eval_subset_num_batches=train_cfg.eval_subset_num_batches, - progress_bar=train_cfg.progress_bar, - log_to_console=train_cfg.log_to_console, - console_log_interval=train_cfg.console_log_interval, - loggers=loggers, - callbacks=callbacks, - precision=train_cfg.precision, - algorithms=algorithms, - device_train_microbatch_size=train_cfg.device_train_microbatch_size, - parallelism_config={'fsdp': fsdp_config}, - save_folder=train_cfg.save_folder, - save_filename=save_filename, - save_latest_filename=save_latest_filename, - save_interval=train_cfg.save_interval, - save_num_checkpoints_to_keep=train_cfg.save_num_checkpoints_to_keep, - save_overwrite=train_cfg.save_overwrite, - save_weights_only=train_cfg.save_weights_only, - load_path=train_cfg.load_path, - load_weights_only=train_cfg.load_weights_only, - load_strict_model_weights=train_cfg.load_strict_model_weights, - load_ignore_keys=train_cfg.load_ignore_keys, - save_ignore_keys=train_cfg.save_ignore_keys, - autoresume=train_cfg.autoresume, - python_log_level=train_cfg.python_log_level, - dist_timeout=train_cfg.dist_timeout, - profiler=profiler, - compile_config=compile_config, - ) - - # Optionally just save an HF checkpoint - if train_cfg.only_hf_checkpoint: - hf_checkpointer_callbacks = [ - c for c in callbacks if isinstance(c, HuggingFaceCheckpointer) - ] - if len(hf_checkpointer_callbacks) == 0: - raise ValueError( - 'No HuggingFaceCheckpointer callback found, but only_hf_checkpoint was set to True. Please add a HuggingFaceCheckpointer.', - ) - if len(hf_checkpointer_callbacks) > 1: - raise ValueError( - 'Multiple HuggingFaceCheckpointer callbacks found, but only_hf_checkpoint was set to True. Please remove all but one HuggingFaceCheckpointer.', - ) - - hf_checkpointer_callback = hf_checkpointer_callbacks[0] - hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger) - return trainer - - if train_cfg.log_config: - log.info('Logging config') - log_config(logged_cfg) - log_dataset_uri(logged_cfg) - torch.cuda.empty_cache() - gc.collect() - - # Eval first if requested - if train_cfg.eval_first and trainer.state.timestamp.batch.value == 0: - trainer.eval() - - log.info('Starting training...') - trainer.fit() - - log.info('Done.') - return trainer +from llmfoundry.train import train_from_yaml if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] - - # Disable resolving environment variables through omegaconf. - om.clear_resolver('oc.env') - - # Load yaml and cli arguments. - with open(yaml_path) as f: - yaml_cfg = om.load(f) - cli_cfg = om.from_cli(args_list) - cfg = om.merge(yaml_cfg, cli_cfg) - assert isinstance(cfg, DictConfig) - main(cfg) + train_from_yaml(yaml_path, args_list) diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 2be1d5139d..a49f1ac07a 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -11,12 +11,12 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from llmfoundry.train import TrainConfig # noqa: E402 +from llmfoundry.train import TRAIN_CONFIG_KEYS, train, validate_config from llmfoundry.utils.config_utils import ( make_dataclass_and_log_config, update_batch_size_info, ) -from scripts.train.train import TrainConfig # noqa: E402 -from scripts.train.train import TRAIN_CONFIG_KEYS, main, validate_config from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg from tests.fixtures.autouse import REPO_DIR @@ -82,7 +82,7 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): test_cfg.max_duration = '1ba' test_cfg.eval_interval = '1ba' test_cfg.loggers = DictConfig({'inmemory': DictConfig({})}) - trainer = main(test_cfg) + trainer = train(test_cfg) assert isinstance(trainer.logger.destinations, tuple) @@ -126,7 +126,7 @@ def test_train_multi_eval(tmp_path: pathlib.Path): test_cfg.max_duration = '1ba' test_cfg.eval_interval = '1ba' test_cfg.loggers = DictConfig({'inmemory': DictConfig({})}) - trainer = main(test_cfg) + trainer = train(test_cfg) assert isinstance(trainer.logger.destinations, tuple) @@ -201,7 +201,7 @@ def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): test_cfg.eval_interval = '1ba' test_cfg.loggers = DictConfig({'inmemory': DictConfig({})}) test_cfg.model['use_train_metrics'] = False - trainer = main(test_cfg) + trainer = train(test_cfg) # Check eval metrics exist inmemorylogger = trainer.logger.destinations[ diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py index 5901d53e94..328a06a69e 100644 --- a/tests/a_scripts/train/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -from scripts.train.train import main # noqa: E402 +from llmfoundry.train import train # noqa: E402 def make_fake_index_file(path: str) -> None: @@ -65,7 +65,7 @@ def test_misspelled_mandatory_params_fail(self, cfg: DictConfig) -> None: with pytest.raises( (omegaconf.errors.MissingMandatoryValue, TypeError, ValueError), ): - main(cfg) + train(cfg) def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: """Check that missing mandatory parameters fail to train.""" @@ -87,7 +87,7 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: omegaconf.errors.InterpolationKeyError, omegaconf.errors.MissingMandatoryValue, )): - main(cfg) + train(cfg) cfg[param] = orig_param def test_optional_misspelled_params_raise_error( @@ -115,7 +115,7 @@ def test_optional_misspelled_params_raise_error( updated_param = param + '-misspelling' cfg[updated_param] = orig_value with pytest.raises(ValueError): - main(cfg) + train(cfg) # restore configs. cfg = copy.deepcopy(old_cfg) @@ -130,7 +130,7 @@ def test_extra_params_in_optimizer_cfg_errors( cfg.eval_loader.dataset.local = data_local cfg.optimizer.beta2 = 'extra-parameter' with pytest.raises(TypeError): - main(cfg) + train(cfg) def test_invalid_name_in_optimizer_cfg_errors( self, @@ -143,7 +143,7 @@ def test_invalid_name_in_optimizer_cfg_errors( cfg.train_loader.dataset.local = data_local cfg.eval_loader.dataset.local = data_local with pytest.raises(ValueError) as exception_info: - main(cfg) + train(cfg) assert str(exception_info.value).startswith( "Cant't find 'invalid-optimizer' in registry llmfoundry -> optimizers.", ) @@ -154,7 +154,7 @@ def test_extra_params_in_scheduler_cfg_errors( ) -> None: cfg.scheduler.t_warmup_extra = 'extra-parameter' with pytest.raises(TypeError): - main(cfg) + train(cfg) def test_invalid_name_in_scheduler_cfg_errors( self, @@ -162,7 +162,7 @@ def test_invalid_name_in_scheduler_cfg_errors( ) -> None: cfg.scheduler.name = 'invalid-scheduler' with pytest.raises(ValueError) as exception_info: - main(cfg) + train(cfg) assert str(exception_info.value).startswith( "Cant't find 'invalid-scheduler' in registry llmfoundry -> schedulers.", ) @@ -181,7 +181,7 @@ def test_no_label_multiple_eval_datasets(self, cfg: DictConfig) -> None: second_eval_loader.label = 'eval_1' cfg.eval_loader = om.create([first_eval_loader, second_eval_loader]) with pytest.raises(ValueError) as exception_info: - main(cfg) + train(cfg) assert str( exception_info.value, ) == 'When specifying multiple evaluation datasets, each one must include the \