diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index 496e905e13..8c86dda2a6 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -16,6 +16,7 @@
 
 from llmfoundry.callbacks.async_eval_callback import AsyncEval
 from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning
+from llmfoundry.callbacks.env_logging_callback import EnvironmentLoggingCallback
 from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
 from llmfoundry.callbacks.eval_output_logging_callback import EvalOutputLogging
 from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
@@ -55,8 +56,8 @@
 callbacks.register('eval_output_logging', func=EvalOutputLogging)
 callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)
 callbacks.register('run_timeout', func=RunTimeoutCallback)
-
 callbacks.register('loss_perp_v_len', func=LossPerpVsContextLengthLogger)
+callbacks.register('env_logger', func=EnvironmentLoggingCallback)
 
 callbacks_with_config.register('async_eval', func=AsyncEval)
 callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
diff --git a/llmfoundry/callbacks/env_logging_callback.py b/llmfoundry/callbacks/env_logging_callback.py
new file mode 100644
index 0000000000..a192390976
--- /dev/null
+++ b/llmfoundry/callbacks/env_logging_callback.py
@@ -0,0 +1,188 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import platform
+import socket
+from typing import Any, Optional
+
+import git
+import pkg_resources
+import psutil
+import torch
+from composer.core import Callback, State
+from composer.loggers import Logger
+from composer.utils import dist
+
+from mcli import sdk
+
+__all__ = ['EnvironmentLoggingCallback']
+
+_PACKAGES_TO_LOG = [
+    'llm-foundry',
+    'mosaicml',
+    'megablocks',
+    'grouped-gemm',
+    'torch',
+    'flash_attn',
+    'transformers',
+    'datasets',
+    'peft',
+]
+
+
+class EnvironmentLoggingCallback(Callback):
+    """A callback for logging environment information during model training.
+
+    This callback collects various pieces of information about the training environment,
+    including git repository details, package versions, system information, GPU details,
+    distributed training setup, NVIDIA driver information, and Docker container details.
+
+    Args:
+        workspace_dir (str): The directory containing the workspace. Defaults to '/workspace'.
+        log_git (bool): Whether to log git repository information. Defaults to True.
+        log_packages (bool): Whether to log package versions. Defaults to True.
+        log_nvidia (bool): Whether to log NVIDIA driver information. Defaults to True.
+        log_docker (bool): Whether to log Docker container information. Defaults to True.
+        log_system (bool): Whether to log system information. Defaults to False.
+        log_gpu (bool): Whether to log GPU information. Defaults to False.
+        log_distributed (bool): Whether to log distributed training information. Defaults to False.
+        packages_to_log (list[str]): A list of package names to log versions for. Defaults to None.
+
+    The collected information is logged as hyperparameters at the start of model fitting.
+    """
+
+    def __init__(
+        self,
+        workspace_dir: str = '/workspace',
+        log_git: bool = True,
+        log_nvidia: bool = True,
+        log_docker: bool = True,
+        log_packages: bool = True,
+        log_system: bool = False,
+        log_gpu: bool = False,
+        log_distributed: bool = False,
+        packages_to_log: Optional[list[str]] = None,
+    ):
+        self.workspace_dir = workspace_dir
+        self.log_git = log_git
+        self.log_packages = log_packages
+        self.log_nvidia = log_nvidia
+        self.log_docker = log_docker
+        self.log_system = log_system
+        self.log_gpu = log_gpu
+        self.log_distributed = log_distributed
+        self.env_data: dict[str, Any] = {}
+        self.packages_to_log = packages_to_log or _PACKAGES_TO_LOG
+
+    def _get_git_info(self, repo_path: str) -> Optional[dict[str, str]]:
+        if not os.path.isdir(repo_path):
+            return None
+        try:
+            repo = git.Repo(repo_path)
+            return {
+                'commit_hash': repo.head.commit.hexsha,
+                'branch': repo.active_branch.name,
+            }
+        except (git.InvalidGitRepositoryError, git.NoSuchPathError):
+            return None
+
+    def _get_package_version(self, package_name: str) -> Optional[str]:
+        try:
+            return pkg_resources.get_distribution(package_name).version
+        except pkg_resources.DistributionNotFound:
+            return None
+
+    def _get_system_info(self) -> dict[str, Any]:
+        return {
+            'python_version': platform.python_version(),
+            'os': f'{platform.system()} {platform.release()}',
+            'hostname': socket.gethostname(),
+            'cpu_info': {
+                'model': platform.processor(),
+                'cores': psutil.cpu_count(logical=False),
+                'threads': psutil.cpu_count(logical=True),
+            },
+            'memory': {
+                'total': psutil.virtual_memory().total,
+                'available': psutil.virtual_memory().available,
+            },
+        }
+
+    def _get_gpu_info(self) -> dict[str, Any]:
+        if torch.cuda.is_available():
+            return {
+                'model': torch.cuda.get_device_name(0),
+                'count': torch.cuda.device_count(),
+                'memory': {
+                    'total': torch.cuda.get_device_properties(0).total_memory,
+                    'allocated': torch.cuda.memory_allocated(0),
+                },
+            }
+        return {'available': False}
+
+    def _get_nvidia_info(self) -> dict[str, Any]:
+        if torch.cuda.is_available():
+            nccl_version = torch.cuda.nccl.version()  # type: ignore
+            return {
+                'cuda_version':
+                    torch.version.cuda,  # type: ignore[attr-defined]
+                'cudnn_version': str(
+                    torch.backends.cudnn.version(),
+                ),  # type: ignore[attr-defined]
+                'nccl_version': '.'.join(map(str, nccl_version)),
+            }
+        return {'available': False}
+
+    def _get_distributed_info(self) -> dict[str, Any]:
+        return {
+            'world_size': dist.get_world_size(),
+            'local_world_size': dist.get_local_world_size(),
+            'rank': dist.get_global_rank(),
+            'local_rank': dist.get_local_rank(),
+        }
+
+    def _get_docker_info(self) -> Optional[dict[str, Any]]:
+        if 'RUN_NAME' not in os.environ:
+            return None
+        run = sdk.get_run(os.environ['RUN_NAME'])
+        image, tag = run.image.split(':')
+        return {
+            'image': image,
+            'tag': tag,
+        }
+
+    def fit_start(self, state: State, logger: Logger) -> None:
+        # Collect environment data
+        if self.log_git:
+            self.env_data['git_info'] = {}
+            for folder in os.listdir(self.workspace_dir):
+                path = self._get_git_info(
+                    os.path.join(self.workspace_dir, folder),
+                )
+                if path:
+                    self.env_data['git_info'][folder] = path
+
+        if self.log_packages:
+            self.env_data['package_versions'] = {
+                pkg: self._get_package_version(pkg)
+                for pkg in self.packages_to_log
+            }
+        if self.log_nvidia:
+            self.env_data['nvidia'] = self._get_nvidia_info()
+
+        if self.log_docker:
+            if docker_info := self._get_docker_info():
+                self.env_data['docker'] = docker_info
+
+        if self.log_system:
+            self.env_data['system_info'] = self._get_system_info()
+
+        if self.log_gpu:
+            self.env_data['gpu_info'] = self._get_gpu_info()
+
+        if self.log_distributed:
+            self.env_data['distributed_info'] = self._get_distributed_info()
+
+        # Log the collected data
+        logger.log_hyperparameters({'environment_data': self.env_data})
diff --git a/setup.py b/setup.py
index 127cde50c5..416a6db4a4 100644
--- a/setup.py
+++ b/setup.py
@@ -73,6 +73,7 @@
     'tenacity>=8.2.3,<9',
     'catalogue>=2,<3',
     'typer<1',
+    'GitPython==3.1.43',
 ]
 
 extra_deps = {}
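
Usage note (not part of the diff above): because the callback is registered under the name env_logger, it can be enabled from a standard llm-foundry train config by adding an env_logger entry under the callbacks: section, with the constructor arguments shown above as its options. For a quick standalone check it can also be attached to a Composer Trainer directly. The sketch below is illustrative only; the toy classifier, dataset, and workspace_dir value are placeholders, not anything defined in this PR.

    # Minimal sketch, assuming composer and llm-foundry are installed.
    # The tiny model and random dataset exist only so that fit() can reach
    # fit_start, where the callback logs its environment data.
    import torch
    from composer import Trainer
    from composer.models import ComposerClassifier
    from torch.utils.data import DataLoader, TensorDataset

    from llmfoundry.callbacks import EnvironmentLoggingCallback

    # Placeholder model and data.
    model = ComposerClassifier(torch.nn.Linear(8, 2), num_classes=2)
    dataset = TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))

    trainer = Trainer(
        model=model,
        train_dataloader=DataLoader(dataset, batch_size=8),
        max_duration='1ep',
        callbacks=[
            EnvironmentLoggingCallback(
                workspace_dir='.',  # directory scanned for git checkouts when log_git=True
                log_system=True,
                log_gpu=True,
            ),
        ],
    )
    trainer.fit()  # environment_data is logged as hyperparameters at fit start

Whatever the callback collects is reported under the environment_data key to the loggers attached to the Trainer (for example, Weights & Biases or MLflow), so the run's git commits, package versions, and hardware details are recorded alongside the usual hyperparameters.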