diff --git a/.gitignore b/.gitignore index d041a25c22..5da48d2cc0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ my-copy-c4*/ my-copy-arxiv*/ *.jsonl* +ygong/notebook/* + # WandB wandb/ @@ -156,3 +158,5 @@ notebooks/ **/mlruns/* **/tokenizer-save-dir-*/** **/.downloaded_finetuning/ + +.databricks diff --git a/llmfoundry/composerpatch/MLFlowLogger.py b/llmfoundry/composerpatch/MLFlowLogger.py new file mode 100644 index 0000000000..32635c6f7d --- /dev/null +++ b/llmfoundry/composerpatch/MLFlowLogger.py @@ -0,0 +1,31 @@ +from composer.loggers import MLFlowLogger as ComposerMLFlowLogger +from composer.utils import dist +import json +import os +from composer.core.state import State +from composer.loggers.logger import Logger + + + +CONFIG_FILE = "/tmp/mlflow_config.yaml" +EXPERIMENT_ID_FIELD = "experiment_id" +RUN_ID_FIELD = "run_id" +TRACKING_URI_FIELD = "tracking_uri" + + +class MLFlowLogger(ComposerMLFlowLogger): + + def init(self, state: State, logger: Logger) -> None: + super().init(state, logger) + + if self._enabled and dist.get_local_rank() == 0: + if os.path.exists(CONFIG_FILE): + os.remove(CONFIG_FILE) + + with open(CONFIG_FILE, "w") as f: + data = { + EXPERIMENT_ID_FIELD: self._experiment_id, + RUN_ID_FIELD: self._run_id, + TRACKING_URI_FIELD : self.tracking_uri, + } + json.dump(data, f) \ No newline at end of file diff --git a/llmfoundry/composerpatch/__init__.py b/llmfoundry/composerpatch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index fe803d62db..a00b190bec 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -14,6 +14,8 @@ from composer.core import Algorithm, Callback, Evaluator from composer.datasets.in_context_learning_evaluation import \ get_icl_task_dataloader +from composer.loggers import (InMemoryLogger, LoggerDestination, + TensorboardLogger, WandBLogger) from composer.loggers import LoggerDestination from composer.models import ComposerModel from composer.optim.scheduler import ComposerScheduler @@ -24,6 +26,8 @@ from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase +from llmfoundry.composerpatch import MLFlowLogger + from llmfoundry import registry from llmfoundry.callbacks import EvalGauntlet from llmfoundry.data.dataloader import build_dataloader @@ -236,15 +240,19 @@ def build_callback( kwargs=kwargs) -def build_logger(name: str, - kwargs: Optional[Dict[str, Any]] = None) -> LoggerDestination: - """Builds a logger from the registry.""" - return construct_from_registry(name=name, - registry=registry.loggers, - partial_function=True, - pre_validation_function=LoggerDestination, - post_validation_function=None, - kwargs=kwargs) +def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: + if name == 'wandb': + return WandBLogger(**kwargs) + elif name == 'tensorboard': + return TensorboardLogger(**kwargs) + elif name == 'in_memory_logger': + return InMemoryLogger(**kwargs) + elif name == 'mlflow': + return MLFlowLogger.MLFlowLogger(**kwargs) + elif name == 'inmemory': + return InMemoryLogger(**kwargs) + else: + raise ValueError(f'Not sure how to build logger: {name}') def build_algorithm(name: str, diff --git a/scripts/train/launcher.py b/scripts/train/launcher.py new file mode 100755 index 0000000000..e86e083e9c --- /dev/null +++ b/scripts/train/launcher.py @@ -0,0 +1,632 @@ +#!/usr/bin/env python3 +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + 
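+# NOTE: this file is a local copy of Composer's distributed-training launcher,
+# extended for this patch: besides spawning one training process per local rank,
+# it watches for the JSON config file written by
+# ``llmfoundry.composerpatch.MLFlowLogger`` (run id, experiment id, tracking URI)
+# and periodically uploads the per-rank stdout/stderr files and the launcher's
+# own log file to that MLflow run as artifacts.
+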
+"""The Composer CLI launcher for distributed training.""" + +import contextlib +import datetime +import logging +import os +import signal +import subprocess +import sys +import tempfile +import time +import traceback +from argparse import ArgumentParser +from typing import Any, Dict, List, Union + +import psutil +import torch + +import composer +from composer.loggers.mosaicml_logger import ( + MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR, + MOSAICML_LOG_DIR_ENV_VAR, + MOSAICML_PLATFORM_ENV_VAR, +) +from composer.utils import get_free_tcp_port +from llmfoundry.composerpatch import MLFlowLogger + + +CLEANUP_TIMEOUT = datetime.timedelta(seconds=30) + +log = logging.getLogger(__name__) + + +def _get_parser(): + parser = ArgumentParser(description='Utility for launching distributed machine learning jobs.') + + parser.add_argument('--version', action='version', version=f'MosaicML Composer {composer.__version__}') + + required_args = parser.add_argument_group('required arguments') + + parser.add_argument( + '-n', + '--nproc', + type=int, + help=( + 'The number of processes to launch on this node. Overrides env var `LOCAL_WORLD_SIZE` if specified; ' + 'otherwise, defaults to `max(1, torch.cuda.device_count())`.' + ), + ) + + parser.add_argument( + '--stdout', + type=str, + default=None, + help=( + 'Format string for a filename to dump the STDOUT from the non-local-rank-zero processes. ' + 'The local rank zero process will be piped through to STDOUT. The available format variables are: ' + "'{rank}', '{local_rank}', '{world_size}', '{node_rank}', and '{local_world_size}'. If specified, " + "it is recommended to include '{rank}' or '{local_rank}' in the filename so each rank will write to its " + 'own file. By default, the STDOUT of the non-local-rank-zero processes is discarded; instead, use the ' + 'FileLogger within Composer. This logger captures and saves the STDOUT of each process.' + ), + ) + parser.add_argument( + '--stderr', + type=str, + default=None, + help=( + 'Format string for a filename to dump the STDERR from the non-local-rank-zero processes. ' + 'The local rank zero process will be piped through to STDERR. The available format variables are: ' + "'{rank}', '{local_rank}', '{world_size}', '{node_rank}', and '{local_world_size}'. If specified, " + "it is recommended to include '{rank}' or '{local_rank}' in the filename so each rank will write to its " + 'own file. By default, the STDERR of the non-local-rank-zero processes is discarded; instead, use the ' + 'FileLogger within Composer. This logger captures and saves the STDERR of each process.' + ), + ) + parser.add_argument('-v', '--verbose', action='store_true', help='If set, print verbose messages') + parser.add_argument( + '-m', + '--module_mode', + action='store_true', + help=( + 'If set, run the training script as a module instead of as a script. ' + 'Cannot be used in conjunction with `command_mode`' + ), + ) + parser.add_argument( + '-c', + '--command_mode', + action='store_true', + help=( + 'If set, run the training script as a command (i.e. without `python`). ' + 'Cannot be used in conjunction with `module_mode`.' + ), + ) + + multinode_args = parser.add_argument_group( + 'multi-node arguments', + description=( + 'These arguments generally only need to be set when training in a multi-node ' + 'environment, i.e. when the world_size is bigger than nproc.' + ), + ) + multinode_args.add_argument( + '--world_size', + type=int, + help=( + 'The total number of processes to launch across all nodes. 
' + 'Setting this to a value greater than nproc indicates a multi-node ' + 'environment. Overrides env var WORLD_SIZE. Defaults to nproc.' + ), + ) + multinode_args.add_argument( + '--base_rank', + type=int, + help=( + 'The rank of the lowest ranked process to launch on this node. ' + 'Specifying a base_rank B and an nproc N will spawn processes with ' + 'global ranks [B, B+1, ... B+N-1]. In a multi-node environment, ' + 'at least one of base_rank and node_rank must be specified. ' + 'If only one of base_rank and node_rank are provided, it is assumed ' + 'that all nodes have the same amount of processes, and that the two ' + 'values are related as node_rank * nproc = base_rank. If this is ' + 'not the case, both base_rank and node_rank must be provided. ' + 'Overrides env var BASE_RANK. Defaults to 0 in a single-node ' + 'environment.' + ), + ) + multinode_args.add_argument( + '--node_rank', + type=int, + help=( + 'The rank of this node. See base_rank for information on when ' + 'this must be provided. Overrides env var NODE_RANK. Defaults to 0 ' + 'in a single-node environment.' + ), + ) + multinode_args.add_argument( + '--master_addr', + type=str, + help=( + 'The FQDN of the node hosting the C10d TCP store. For single-node ' + 'operation, this can generally be left as 127.0.0.1. Overrides env var ' + 'MASTER_ADDR. Defaults to 127.0.0.1 in a single-node environment.' + ), + ) + multinode_args.add_argument( + '--master_port', + type=int, + help=( + 'The port on the master hosting the C10d TCP store. If you are ' + 'running multiple trainers on a single node, this generally needs ' + 'to be unique for each one. Overrides env var MASTER_PORT. Defaults ' + 'to a random free port in a single-node environment.' + ), + ) + + required_args.add_argument( + 'training_script', + type=str, + help=( + 'The path to the training script used to initialize a single training ' + 'process. Should be followed by any command-line arguments the script ' + 'should be launched with.' 
+ ), + ) + required_args.add_argument( + 'training_script_args', + nargs='...', + help='Any arguments for the training script, given in the expected order.', + ) + + return parser + + +def _parse_args(): + parser = _get_parser() + + args = parser.parse_args() + + # Default values to env vars if they are not provided + if args.nproc is None: + if 'LOCAL_WORLD_SIZE' in os.environ: + args.nproc = int(os.environ['LOCAL_WORLD_SIZE']) + else: + args.nproc = torch.cuda.device_count() + + if args.nproc == 0: + # This could happen if doing cpu-only training, + # which could cause torch.cuda.device_count() to return 0, + # and LOCAL_WORLD_SIZE (as set by MCLI) to be zero + args.nproc = 1 + + if args.nproc < 1: + raise ValueError('The nproc must be 1 or greater') + + if args.world_size is None and 'WORLD_SIZE' in os.environ: + args.world_size = int(os.environ['WORLD_SIZE']) + + if args.base_rank is None and 'BASE_RANK' in os.environ: + args.base_rank = int(os.environ['BASE_RANK']) + + if args.node_rank is None and 'NODE_RANK' in os.environ: + args.node_rank = int(os.environ['NODE_RANK']) + + if args.master_addr is None and 'MASTER_ADDR' in os.environ: + args.master_addr = os.environ['MASTER_ADDR'] + + if args.master_port is None and 'MASTER_PORT' in os.environ: + args.master_port = int(os.environ['MASTER_PORT']) + + if args.world_size is None: + args.world_size = args.nproc + + if args.world_size < args.nproc: + raise ValueError(f'world_size({args.world_size}) cannot be less than nproc({args.nproc})') + + if args.world_size < 1: + raise ValueError('The world_size must be 1 or greater') + + is_multinode = args.world_size > args.nproc + + if is_multinode: + if args.base_rank is None and args.node_rank is None: + raise ValueError(f'In a multi-node environment, at least one of node_rank and base_rank must be provided.') + + if args.node_rank is None: + if args.world_size % args.nproc != 0 or args.base_rank % args.nproc != 0: + raise ValueError( + 'node_rank not provided, but unable to infer from base_rank since nodes appear to ' + 'have different amounts of processes. Please also specify node_rank.', + ) + args.node_rank = args.base_rank // args.nproc + + if args.base_rank is None: + if args.world_size % args.nproc != 0: + raise ValueError( + 'base_rank not provided, but unable to infer from node_rank since nodes appear to ' + 'have different amounts of processes. 
Please also provide base_rank.', + ) + args.base_rank = args.node_rank * args.nproc + + if args.base_rank + args.nproc > args.world_size: + raise ValueError( + f'Cannot initialize processes for node with base_rank({args.base_rank}) and ' + f'nproc({args.nproc}) because this would mean creating a process with ' + f'rank({args.base_rank + args.nproc - 1}), and all processes must have smaller rank than ' + f'the world_size({args.world_size}).', + ) + + if args.master_addr is None: + raise ValueError('In a multi-node environment, master_addr is required.') + + if args.master_port is None: + raise ValueError('In a multi-node environment, master_port is required.') + + else: + if args.base_rank is not None and args.base_rank != 0: + raise ValueError(f'base_rank({args.base_rank}) != 0 is not valid in a single-node environment.') + args.base_rank = 0 + + if args.node_rank is not None and args.node_rank != 0: + raise ValueError(f'node_rank({args.node_rank}) != 0 is not valid in a single-node environment.') + args.node_rank = 0 + + if args.master_addr is None: + args.master_addr = '127.0.0.1' + + if args.master_port is None: + args.master_port = get_free_tcp_port() + + return args + + +@contextlib.contextmanager +def _patch_env(**environs: str): + """Returns a context manager that patches ``os.environ`` with ``environs``. + + The original ``os.environ`` values are restored at the end. + """ + # Adapted loosely from https://stackoverflow.com/a/34333710 + # Capture the original environ values + original_environs = {k: os.environ.get(k) for k in environs} + + # Patch the environment + for k, v in environs.items(): + os.environ[k] = v + try: + # Run the context manager + yield + finally: + # Restore the original environ values + for k, v in original_environs.items(): + if v is None: + del os.environ[k] + else: + os.environ[k] = v + + +def _launch_processes( + nproc: int, + world_size: int, + base_rank: int, + node_rank: int, + master_addr: str, + master_port: int, + module_mode: bool, + command_mode: bool, + training_script: str, + stdout_file_format: str, + stderr_file_format: Union[str, None], + training_script_args: List[Any], + processes: Dict[int, subprocess.Popen], + log_dirs: set[str], +): + log.info('Starting distributed environment on local node for global_rank(%s-%s)', base_rank, base_rank + nproc - 1) + log.info('Distributed KV store: tcp://%s:%s', master_addr, master_port) + log.warning(f"ygong: stdout_file_format={stdout_file_format}, stderr_file_format={stderr_file_format}") + + for local_rank in range(nproc): + global_rank = base_rank + local_rank + if command_mode and module_mode: + raise ValueError('Either `command_mode` or `module_mode` should be set, but not both.') + cmd = [] + if not command_mode: + cmd.append(sys.executable) + if module_mode: + cmd.append('-m') + + cmd.append(training_script) + + # Update the env with the distributed variables + with _patch_env( + RANK=str(global_rank), + WORLD_SIZE=str(world_size), + LOCAL_RANK=str(local_rank), + LOCAL_WORLD_SIZE=str(nproc), + NODE_RANK=str(node_rank), + MASTER_ADDR=master_addr, + MASTER_PORT=str(master_port), + PYTHONUNBUFFERED='1', + NCCL_ASYNC_ERROR_HANDLING='1', + ): + # Populate the distributed variables in all launcher args + for arg in training_script_args: + cmd.append(os.path.expandvars(os.path.expanduser(arg))) + + log.info( + 'Launching process for local_rank(%s), global_rank(%s) with command(%s)', + local_rank, + global_rank, + cmd, + ) + + if local_rank == 0: + process = subprocess.Popen( + cmd, + text=True, + ) + 
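+            # Local rank 0 inherits the launcher's stdout/stderr, so its output goes
+            # straight to the console. The remaining ranks are redirected to per-rank
+            # files below; their directories are recorded in ``log_dirs`` so the
+            # monitor loop can upload them to MLflow as artifacts.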
else: + + def _get_file__(format: str): + filename = format.format( + rank=global_rank, + world_size=world_size, + local_rank=local_rank, + local_world_size=nproc, + node_rank=node_rank, + ) + dir = os.path.normpath(os.path.dirname(filename)) + os.makedirs(dir, exist_ok=True) + log_dirs.add(dir) + return open(filename, 'x+') + + stdout_file = _get_file__(stdout_file_format) + stderr_file = _get_file__(stderr_file_format) if stderr_file_format is not None else None + + process = subprocess.Popen( + cmd, + stdout=stdout_file, + stderr=stderr_file if stderr_file is not None else subprocess.STDOUT, + text=True, + ) + process.stdout = stdout_file + if stderr_file is not None: + process.stderr = stderr_file + processes[global_rank] = process + +def _logs_upload_to_mlflow(log_dirs: set[str], launcher_log: str): + import mlflow + # intialize mlflow experiment. Need to wait for processors to start the experiment + if os.environ.get("mlflow_runid") is None and os.path.exists(MLFlowLogger.CONFIG_FILE): + import json + with open(MLFlowLogger.CONFIG_FILE, "r") as f: + data = json.load(f) + log.error(f"ygong:Started mlflow run {data}") + os.environ['mlflow_runid'] = data[MLFlowLogger.RUN_ID_FIELD] + + mlflow.set_tracking_uri(data[MLFlowLogger.TRACKING_URI_FIELD]) + mlflow.start_run(run_id=data[MLFlowLogger.RUN_ID_FIELD], experiment_id=data[MLFlowLogger.EXPERIMENT_ID_FIELD]) + + # once the mlflow experiment is started, upload the logs + if os.environ.get("mlflow_runid") is not None: + try: + for log_dir in log_dirs: + log.warning(f"ygong: Logging directory: {log_dir}") + mlflow.log_artifacts(log_dir, log_dir.lstrip('/')) + mlflow.log_artifact(launcher_log) + except Exception as e: + log.error(f"ygong:Failed to log artifacts to mlflow: {e}") + + +def _monitor_processes( + processes: Dict[int, subprocess.Popen], + log_dirs: set[str], + launcher_log: str,): + import mlflow + log_frequency = 200 + cycle = 0 + + mlflow_runid = None + try: + while True: + process_has_crashed = False + all_processes_finished = True + for global_rank, process in processes.items(): + if process.poll() is None: + # the process is still running + all_processes_finished = False + continue + else: + # return code of 0 implies clean exit + if process.returncode != 0: + log.error(f'Rank {global_rank} crashed with exit code {process.returncode}.') + process_has_crashed = True + break + else: + # exited cleanly + log.info(f'Rank {global_rank} finished successfully.') + if process_has_crashed or all_processes_finished: + break + + if cycle == 0: + _logs_upload_to_mlflow(log_dirs, launcher_log) + cycle = (cycle + 1) % log_frequency + + time.sleep(0.1) + + except KeyboardInterrupt: + print('Ctrl-C received; terminating training processes.') + pass + + +def _print_process_exit_status(global_rank: int, process: subprocess.Popen): + stdOutLabel = 'STDOUT' + if process.stdout is None: + output = None + else: + process.stdout.seek(0) + output = process.stdout.read() + + if process.stderr is None: + stderr = None + stdOutLabel = 'logs' + else: + process.stderr.seek(0) + stderr = process.stderr.read() + exc = subprocess.CalledProcessError( + process.returncode, + cmd=process.args, + output=output, + stderr=stderr, + ) + + error_msg = [f'Global rank {global_rank} (PID {process.pid}) exited with code {process.returncode}'] + if output is not None: + error_msg.extend([ + f'----------Begin global rank {global_rank} {stdOutLabel}----------', + output, + f'----------End global rank {global_rank} {stdOutLabel}----------', + ]) + + if stderr is not None: 
+ error_msg.extend([ + f'----------Begin global rank {global_rank} STDERR----------', + exc.stderr, + f'----------End global rank {global_rank} STDERR----------', + ]) + print('\n'.join(error_msg)) + + +def _cleanup_processes(processes: Dict[int, subprocess.Popen]): + for global_rank, process in processes.items(): + process.poll() + if process.returncode is None: + log.info('Killing global rank %s (PID %s) with SIGTERM', global_rank, process.pid) + # Assuming that child processes correctly handle SIGTERM to cleanup any children + try: + os.kill(process.pid, signal.SIGTERM) + except ProcessLookupError: + pass + + current_time = datetime.datetime.now() + + try: + print(( + f'Waiting up to {CLEANUP_TIMEOUT.seconds} seconds for all training processes to terminate. ' + 'Press Ctrl-C to exit immediately.' + )) + while datetime.datetime.now() - current_time < CLEANUP_TIMEOUT: + for process in processes.values(): + process.poll() + if all(process.returncode is not None for process in processes.values()): + break + time.sleep(0.1) + except KeyboardInterrupt: + pass + + for global_rank, process in processes.items(): + process.poll() + if process.returncode is None: + log.warning( + 'Failed to kill global rank %s (PID %s) with SIGTERM; terminating with SIGKILL instead', + global_rank, + process.pid, + ) + try: + proc = psutil.Process(process.pid) + except psutil.NoSuchProcess: + pass + else: + # If using SIGKILL, manually kill all child processes, since the main training process + # likely won't be able to intercept the signal and clean up its children. + for psutil_proc in [proc, *proc.children(recursive=True)]: + try: + os.kill(psutil_proc.pid, signal.SIGKILL) + except ProcessLookupError: + pass + for global_rank, process in processes.items(): + process.poll() + if process.returncode is not None and process.returncode != 0: + if -process.returncode in (signal.SIGKILL, signal.SIGTERM): + # Negative return codes indicate the process was killed via a signal + # If the launcher script killed the training process (which would happen via SIGKILL or SIGTERM), + # then do not print the stack trace. 
+ continue + # only print the processes that have actually crashed, + # not the ones that were killed + _print_process_exit_status(global_rank, process) + + +def _aggregate_process_returncode(processes: Dict[int, subprocess.Popen]) -> int: + for global_rank, process in processes.items(): + process.poll() + if process.returncode is None: + log.error('Global rank %s (PID %s) has still not exited; return exit code 1.', global_rank, process.pid) + return 1 + if process.returncode != 0: + log.error('Global rank %s (PID %s) exited with code %s', global_rank, process.pid, process.returncode) + return process.returncode + + return 0 + + +def main(): + """Entrypoint into the Composer CLI.""" + args = _parse_args() + + log_tmpdir = tempfile.TemporaryDirectory() + launcher_log = f"{log_tmpdir.name}/launcher{args.node_rank}.log" + logging.basicConfig(filename=launcher_log, level=logging.INFO if args.verbose else logging.WARNING) + + processes = {} + log_dirs = set() + + if args.stdout is None: + args.stdout = f'{log_tmpdir.name}/rank{{rank}}.stdout.txt' + if args.stderr is None: + args.stderr = f'{log_tmpdir.name}/rank{{rank}}.stderr.txt' + + # If running on the Mosaic platform, log all gpu ranks' stderr and stdout to Mosaic platform + if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true' and str( + os.environ.get(MOSAICML_LOG_DIR_ENV_VAR, 'false'), + ).lower() != 'false' and os.environ.get(MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR, 'false').lower() != 'false': + log.info('Logging all GPU ranks to Mosaic Platform.') + if args.stderr is not None or args.stdout is not None: + log.info( + 'Logging to Mosaic Platform. Ignoring provided stdout and stderr args. To use provided stdout and stderr, set MOSAICML_LOG_DIR=false.', + ) + + args.stdout = f'{os.environ.get(MOSAICML_LOG_DIR_ENV_VAR)}/stdout/{os.environ.get(MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR)}{{rank}}.txt' + args.stderr = f'{os.environ.get(MOSAICML_LOG_DIR_ENV_VAR)}/stderr/{os.environ.get(MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR)}{{rank}}.txt' + + try: + _launch_processes( + nproc=args.nproc, + world_size=args.world_size, + base_rank=args.base_rank, + node_rank=args.node_rank, + master_addr=args.master_addr, + master_port=args.master_port, + module_mode=args.module_mode, + command_mode=args.command_mode, + stdout_file_format=args.stdout, + stderr_file_format=args.stderr, + training_script=args.training_script, + training_script_args=args.training_script_args, + processes=processes, + log_dirs=log_dirs, + ) + _monitor_processes(processes, log_dirs, launcher_log) + except: + # Print the exception first, then kill the training processes, since killing + # may take up to CLEANUP_TIMEOUT seconds, and the user should know immediately + # what failed. No need to re-raise the exception, as `aggregate_process_returncode` + # will return an appropriate error code, which will cause the script to exit. 
+ traceback.print_exc() + print('Killing training processes') + finally: + # upload the logs before exit + _logs_upload_to_mlflow(log_dirs=log_dirs, launcher_log=launcher_log) + _cleanup_processes(processes) + log_tmpdir.cleanup() + return _aggregate_process_returncode(processes) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scripts/train/train.py b/scripts/train/train.py index 44cfc053f4..dabf5e4a22 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -408,6 +408,7 @@ def main(cfg: DictConfig) -> Trainer: # mosaicml_logger will be None if run isn't on MosaicML platform loggers.append(mosaicml_logger) + if metadata is not None: # Flatten the metadata for logging logged_cfg.pop('metadata', None) diff --git a/setup.py b/setup.py index 22b7cb17ca..b49b80d6ef 100644 --- a/setup.py +++ b/setup.py @@ -83,10 +83,14 @@ 'pytest-cov>=4,<5', 'pyright==1.1.256', 'toml>=0.10.2,<0.11', - 'packaging>=21,<23', + 'packaging>=21', 'hf_transfer==0.1.3', ] +extra_deps['ygong'] = [ + 'databricks-genai', +] + extra_deps['databricks'] = [ 'mosaicml[databricks]>=0.21.1,<0.22', 'databricks-sql-connector>=3,<4', diff --git a/ygong/__init__.py b/ygong/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ygong/mosaic/__init__.py b/ygong/mosaic/__init__.py new file mode 100644 index 0000000000..1142627308 --- /dev/null +++ b/ygong/mosaic/__init__.py @@ -0,0 +1,6 @@ +from .submit import submit +from .submit import _set_up_environment +from .scaling_config import ScalingConfig +from .mpt125mConfig import MPT125MConfig, WSFSIntegration + +__all__ = ['submit', 'ScalingConfig', "MPT125MConfig", "WSFSIntegration", "_set_up_environment"] \ No newline at end of file diff --git a/ygong/mosaic/mpt125mConfig.py b/ygong/mosaic/mpt125mConfig.py new file mode 100644 index 0000000000..b2d5c54efd --- /dev/null +++ b/ygong/mosaic/mpt125mConfig.py @@ -0,0 +1,246 @@ +from mcli import RunConfig +from ygong.mosaic.scaling_config import ScalingConfig +from typing import Dict, List, Optional +import os +import shlex +# import databricks_genai.api.config as cfg + +class WSFSIntegration: + def __init__( + self, + wsfs_path: str, + entry_point: Optional[str] = None, + args: Optional[List[str]] = None): + """ + Class to represent the integration with Databricks WSFS. 
+ + :params: wsfs_path: str Absolute path + :params: entry_point: str Required if the wsfs_path is a directory + """ + self.wsfs_path = wsfs_path + self.entry_point = entry_point + self.args = args + + def get_entry_command(self): + entry_file_path = "" + if self.entry_point is not None: + if self.entry_point.startswith("/Workspace"): + entry_file_path = self.entry_point + else: + entry_file_path = os.path.join(self.wsfs_path, self.entry_point) + else: + entry_file_path = self.wsfs_path + if self.args is None: + return f"python3 {shlex.quote(entry_file_path)}" + return f"python3 {shlex.quote(entry_file_path)} {' '.join(self.args)}" + + def toDict(self): + return { + "integration_type": "wsfs", + "wsfs_path": self.wsfs_path, + "entrypoint": self.entry_point, + "args": self.args, + } + + +class MPT125MConfig: + def __init__( + self, + experimentName: str, + data: str, + priority: str = 'high', + preemptible: bool = False, + retry_on_system_failure: bool = False, + wsfs_integration: Optional[WSFSIntegration] = None): + # TODO: validate the inputs and remove the yu.gong hardcode + self.mlflow_experimentName = f"/Users/yu.gong@databricks.com/{experimentName}" + self.mlflow_trackingUri = "databricks" + # self.mlflow_trackingUri = "databricks" if workspace_url is None else workspace_url + + self.data = data + + # Scheudling parameters + self.priority = priority + self.preemptible = preemptible + self.retry_on_system_failure = retry_on_system_failure + + # the run name is pre-configured for all config-driven pretrain runs + self.name = "mpt125m-config-driven-pretrain" + + ######################################## + # model parameters + ######################################## + self.max_seq_len = 2048 + self.global_seed = 17 + self.data_remote = self.data + self.data_local = "" + self.commands = [] + if wsfs_integration is not None: + # The first group of commands are to download the object(file or directory) from + # databricks WSFS using PAT token and url. + # The second command try to unzip if the object from WSFS is directory. + # TODO: Read the token and host name from env vars or /mnt/jwt-secret/.databrickscfg + self.commands = [ + f""" + DATABRICKS_HOST="https://oregon.staging.cloud.databricks.com" + DATABRICKS_TOKEN="dapid5af61ff89674be90c3e86ae9fc10c2e" + WSFS_PATH="{wsfs_integration.wsfs_path}" + DIR_NAME=$(dirname "$WSFS_PATH") + ENCODED_WSFS_PATH=$(python3 -c "import urllib.parse; print(urllib.parse.quote('$WSFS_PATH'))") + mkdir -p "$DIR_NAME" + curl -X GET -o "$WSFS_PATH" "${{DATABRICKS_HOST}}/api/2.0/workspace/export?path=${{ENCODED_WSFS_PATH}}&direct_download=true" \ + -H "Authorization: Bearer $DATABRICKS_TOKEN" + + if file "$WSFS_PATH" | grep -q "Zip archive data"; then + mv "$WSFS_PATH" "${{WSFS_PATH}}.zip" + apt update && apt install unzip + unzip -d "$DIR_NAME" "${{WSFS_PATH}}.zip" + rm -f "${{WSFS_PATH}}.zip" + else + echo "$WSFS_PATH is not a ZIP file." 
+ fi + """ + ] + self.commands.append(wsfs_integration.get_entry_command()) + else: + self.commands = [ + "cd llm-foundry/scripts", + "train/launcher.py train/train.py /mnt/config/parameters.yaml train_loader.dataset.split=train eval_loader.dataset.split=val" + ] + + + def toRunConfig(self, scalingConfig: ScalingConfig): + return RunConfig( + name=self.name, + image='mosaicml/llm-foundry:2.2.1_cu121_flash2-latest', + command="\n".join(self.commands), + compute=scalingConfig.toCompute, + scheduling={ + 'priority': self.priority, + 'preemptible': self.preemptible, + 'retry_on_system_failure': self.retry_on_system_failure + }, + integrations=[ + { + 'integration_type': 'git_repo', + 'git_repo': 'shitaoli-db/llm-foundry', + 'git_branch': 'shitao.li@databricks.com/prototype-shitao', + 'pip_install': '-e .[gpu]', + 'ssh_clone': False + }, + { + 'integration_type': 'pip_packages', + 'packages': ['pynvml', 'mosaicml-streaming[databricks]'], + }, + ], + parameters=self.parameters(), + env_variables={}, + ) + + def parameters(self): + return { + "data_local": self.data_local, + "data_remote": self.data, + "max_seq_len": self.max_seq_len, + "global_seed": self.global_seed, + "run_name": None, + "model": { + "name": "mpt_causal_lm", + "init_device": "meta", + "d_model": 768, + "n_heads": 12, + "n_layers": 12, + "expansion_ratio": 4, + "max_seq_len": self.max_seq_len, + "vocab_size": 50368, + "attn_config": { + "attn_impl": "flash" + } + }, + "tokenizer": { + "name": "EleutherAI/gpt-neox-20b", + "kwargs": { + "model_max_length": self.max_seq_len + } + }, + "train_loader": { + "name": "text", + "dataset": { + "local": f"{self.data_local}", + "remote": f"{self.data_remote}", + "split": "train", + "shuffle": True, + "max_seq_len": self.max_seq_len, + "shuffle_seed": self.global_seed + }, + "drop_last": True, + "num_workers": 8 + }, + "eval_loader": { + "name": "text", + "dataset": { + "local": f"{self.data_local}", + "remote": f"{self.data_remote}", + "split": "val", + "shuffle": False, + "max_seq_len": self.max_seq_len, + "shuffle_seed": self.global_seed + }, + "drop_last": False, + "num_workers": 8 + }, + "scheduler": { + "name": "cosine_with_warmup", + "t_warmup": "100ba", + "alpha_f": 0.1 + }, + "optimizer": { + "name": "decoupled_adamw", + "lr": 6.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-08, + "weight_decay": 0.0 + }, + "algorithms": { + "gradient_clipping": { + "clipping_type": "norm", + "clipping_threshold": 1.0 + } + }, + "max_duration": "480ba", # ~ 2.5B tokens, original + "eval_interval": "50ba", # original 500 + "eval_first": False, + "eval_subset_num_batches": -1, + "global_train_batch_size": 256, + "seed": self.global_seed, + "device_eval_batch_size": 16, + "device_train_microbatch_size": 16, + "precision": "amp_bf16", + "fsdp_config": { + "sharding_strategy": "FULL_SHARD", + "mixed_precision": "PURE", + "activation_checkpointing": False, + "activation_checkpointing_reentrant": False, + "activation_cpu_offload": False, + "limit_all_gathers": True + }, + "progress_bar": False, + "log_to_console": True, + "console_log_interval": "10ba", + "callbacks": { + "speed_monitor": { + "window_size": 10 + }, + "lr_monitor": {}, + "memory_monitor": {}, + "runtime_estimator": {} + }, + "loggers": { + "mlflow": { + "experiment_name": self.mlflow_experimentName, + "tracking_uri": "databricks", + "synchronous": False, + "log_system_metrics": True + } + } + } \ No newline at end of file diff --git a/ygong/mosaic/scaling_config.py b/ygong/mosaic/scaling_config.py new file mode 100644 index 
0000000000..f22e8e7e4c --- /dev/null +++ b/ygong/mosaic/scaling_config.py @@ -0,0 +1,14 @@ +class ScalingConfig: + def __init__(self, gpusNum: int, gpuType: str, poolName: str): + # TODO: validate the inputs + self.gpusNum = gpusNum + self.gpuType = gpuType + self.poolName = poolName + + @property + def toCompute(self): + return { + 'gpus': self.gpusNum, + 'gpu_type': self.gpuType, + 'cluster': self.poolName + } diff --git a/ygong/mosaic/submit.py b/ygong/mosaic/submit.py new file mode 100644 index 0000000000..c1a25245a6 --- /dev/null +++ b/ygong/mosaic/submit.py @@ -0,0 +1,202 @@ +from ygong.mosaic.scaling_config import ScalingConfig +from ygong.mosaic.mpt125mConfig import MPT125MConfig + +from databricks.sdk import WorkspaceClient +from mcli import config, Run, RunStatus, create_run +from mcli.api.runs.api_get_runs import get_run +from mcli.cli.m_get.runs import RunDisplayItem +from IPython.display import display, clear_output, HTML +import ipywidgets as widgets +import mlflow +import pandas as pd + +from typing import Optional +import base64 +import time +import json +import logging +import os +import sys +from mcli import RunConfig +import hashlib +from mcli.config import MCLIConfig +from mcli.api.engine.engine import MAPIConnection + +logger = logging.getLogger('ygong.mosaic.submit') + +def _set_up_environment(content: str): + os.environ['CREDENTIALS'] = content + + +def _init_connection(): + def _is_local(): + if os.environ.get('CREDENTIALS') is not None: + return True + try: + wc = WorkspaceClient() + wc.dbutils.entry_point.getDbutils().notebook().getContext() + return False + except: + return True + + if _is_local(): + logger.debug("init_connection in local mode") + if os.environ.get('CREDENTIALS') is None: + raise ValueError("_set_up_environment must be manually called to configure credentials for local runs") + data = json.loads(base64.b64decode(os.environ.get('CREDENTIALS')).decode('utf-8')) + workspace_url = data.get("workspace_url", None) + token = data.get("token", None) + # set up the mosaic token + os.environ[config.MCLI_MODE_ENV] = config.MCLIMode.DBX_AWS_STAGING.value + os.environ[config.MOSAICML_ACCESS_TOKEN_FILE_ENV] = "/home/shitao.li/e2_token" + else: + logger.debug("init_connection in databricks environment") + wc = WorkspaceClient() + ctx = wc.dbutils.entry_point.getDbutils().notebook().getContext() + token = ctx.apiToken().get() + api_url = ctx.apiUrl().get() + endpoint = f'{api_url}/api/2.0/genai-mapi/graphql' + workspace_url = api_url + os.environ[config.MOSAICML_API_KEY_ENV] = f'Bearer {token}' + os.environ[config.MOSAICML_API_ENDPOINT_ENV] = endpoint + try: + jobs_id = ctx.jobId().get() + os.environ['JOB_ID'] = jobs_id + except: + pass + + # needed to set up the MLFlow query for experiment runs + os.environ['WORKSPACE_URL'] = workspace_url + os.environ['MLFLOW_TRACKING_TOKEN'] = token + logger.debug(f"init_connection token: {os.environ['MLFLOW_TRACKING_TOKEN']}, workspace: {os.environ['WORKSPACE_URL']}, is_jobs: {os.environ.get('JOB_ID')}") + + +def get_experiment_run_url(tracking_uri: Optional[str], experiment_name: str, run_name: str): + if tracking_uri is None: + raise ValueError("tracking_uri must be provided") + mlflow.set_tracking_uri(tracking_uri) + tracking_uri = tracking_uri.rstrip("/") + experiment = mlflow.get_experiment_by_name(name=experiment_name) + if experiment is None: + raise ValueError(f"experiment {experiment_name} does not exist") + experiment_id = experiment.experiment_id + runs = mlflow.search_runs(experiment_ids=[experiment_id], + 
filter_string=f'tags.composer_run_name = "{run_name}"', + output_format='list') + if len(runs) == 0: + return None + elif len(runs) > 1: + raise ValueError(f"multiple runs {run_name} exist in experiment {experiment_name}") + else: + run_id = runs[0].info.run_id + return f"{tracking_uri}/ml/experiments/{experiment_id}/runs/{run_id}" + + +def _get_run_summary(run: Run, experiment_name: Optional[str] = None): + url = None + + run_rows = [] + + # Copy pasted from mcli to display the the resumption status of the run. + for row_raw in RunDisplayItem.from_run(run, [], True): + row = row_raw.to_dict() + if row['Status'].startswith('Running') and experiment_name is not None: + url = get_experiment_run_url(os.environ.get('WORKSPACE_URL'), experiment_name, run.name) + row['Experiment Run'] =f'Link' if url is not None else "" + run_rows.append(row) + + df = pd.DataFrame(run_rows) + return df + +def _display_run_summary(summary: pd.DataFrame, cancel_button: Optional[widgets.Button]): + clear_output(wait=True) + if cancel_button is not None: + display(cancel_button) + display(HTML(summary.to_html(escape=False))) + +def _wait_for_run_status(run: Run, status: RunStatus, inclusive: bool = True): + run_name = run.name + while not run.status.after(status, inclusive=inclusive): + time.sleep(5) + run = get_run(run_name) + logger.debug(f"run status {run.status}, expected status {status}") + logger.debug(f"finish waiting run reached expected status {status}") + return run + +def submit(config: any, scalingConfig: ScalingConfig, sync: bool = False, debug: bool = False): + if debug: + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.DEBUG) # Set minimum log level for the handler + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + stdout_handler.setFormatter(formatter) + + # Add the handler to the logger + logger.addHandler(stdout_handler) + logger.setLevel(logging.DEBUG) + + logger.info("set the logger to debug mode") + + # MTC + AWS Dogfood + _init_connection() + mlflow_experiment_name = None + if isinstance(config, MPT125MConfig): + mlflow_experiment_name = config.mlflow_experimentName + runConfig = config.toRunConfig(scalingConfig) + elif isinstance(config, RunConfig): + runConfig = config + mlflow_experiment_name = runConfig.name + else: + raise ValueError(f"config type {type(config)} is not supported") + + run = create_run(runConfig) + run_name = run.name + # Create a button + if os.environ.get('JOB_ID') is not None: + # running in jobs workflow, no need to cancel the run and doesn't support widgets + button = None + else: + button = widgets.Button(description="cancel the run") + def on_button_clicked(b): + logger.debug(f"cancel button clicked") + clear_output(wait=False) + run = get_run(run_name) + run.stop() + logger.debug(f"run {run_name} is cancelled") + run = _wait_for_run_status(run, RunStatus.TERMINATING) + summary = _get_run_summary(run, mlflow_experiment_name) + display(HTML(summary.to_html(escape=False))) + button.on_click(on_button_clicked) + _display_run_summary(_get_run_summary(run, mlflow_experiment_name), button) + run = _wait_for_run_status(run, RunStatus.RUNNING) + + def _wait_for_run_finish(run: Run): + run_name = run.name + while not run.status.is_terminal(): + run = get_run(run_name) + _display_run_summary(_get_run_summary(run, mlflow_experiment_name), button) + time.sleep(5) + logger.debug(f"finish waiting run reached terminal") + return run + + try_count = 0 + while try_count < 10: + try_count += 1 + time.sleep(20) 
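+        # The MLflow experiment run is created asynchronously by the training job, so
+        # poll at most 10 times, 20s apart; _get_run_summary raises ValueError while
+        # the MLflow experiment does not exist yet.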
+        try:
+            run = get_run(run)
+            if run.status.is_terminal():
+                logger.debug(f"run {run_name} is in terminal state. Status {run.status}")
+                break
+            summary = _get_run_summary(run, mlflow_experiment_name)
+            _display_run_summary(summary, button)
+            break
+        except ValueError:
+            logger.debug(f"waiting for the MLflow experiment run to be ready, run status {run.status}")
+            pass
+
+    if sync:
+        logger.debug("synchronously waiting for the run to finish.")
+        run = _wait_for_run_finish(run)
+        _display_run_summary(_get_run_summary(run, mlflow_experiment_name), None)
+
+    return run
\ No newline at end of file
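
For reference, a minimal usage sketch of the new ygong.mosaic entry points added above. The constructor and function signatures come from this diff; the workspace URL, token, pool name, GPU type, and dataset path are illustrative placeholders.

# Usage sketch only: credentials, pool name, GPU type, and dataset path are placeholders.
import base64
import json

from ygong.mosaic import MPT125MConfig, ScalingConfig, submit, _set_up_environment

# When running outside a Databricks notebook, credentials must be supplied explicitly as a
# base64-encoded JSON blob containing the workspace URL and a PAT token (see _init_connection).
_set_up_environment(
    base64.b64encode(json.dumps({
        "workspace_url": "https://<workspace-host>",
        "token": "<databricks-pat>",
    }).encode("utf-8")).decode("utf-8"))

# Request 8 GPUs from a named compute pool (GPU type is a placeholder value).
scaling = ScalingConfig(gpusNum=8, gpuType="a100_80gb", poolName="<pool-name>")

# Config-driven MPT-125M pretraining; `data` is the remote MDS dataset location.
config = MPT125MConfig(
    experimentName="mpt125m-demo",
    data="s3://<bucket>/c4-mds",
)

# Creates the MCLI run, renders the status table (with a cancel button when run
# interactively), and blocks until the run finishes because sync=True.
run = submit(config, scaling, sync=True)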