From 2d65fc20d88ce77c5c012566fa8b94766c543782 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 25 Mar 2024 13:11:12 -0700 Subject: [PATCH] Models registry (#1057) --- llmfoundry/__init__.py | 2 - .../callbacks/curriculum_learning_callback.py | 4 +- llmfoundry/models/__init__.py | 18 ++++++ llmfoundry/models/hf/hf_t5.py | 4 +- .../models/inference_api_wrapper/fmapi.py | 16 ++--- .../models/inference_api_wrapper/interface.py | 8 ++- .../inference_api_wrapper/openai_causal_lm.py | 28 +++++---- llmfoundry/models/model_registry.py | 21 ------- llmfoundry/optim/scheduler.py | 4 +- llmfoundry/registry.py | 12 ++++ llmfoundry/utils/builders.py | 58 ++++++++++++++++++- llmfoundry/utils/warnings.py | 20 ++++++- scripts/eval/eval.py | 47 ++++----------- scripts/inference/benchmarking/benchmark.py | 15 ++++- scripts/train/train.py | 34 ++++------- tests/a_scripts/eval/test_eval.py | 10 +++- .../inference/test_convert_composer_to_hf.py | 25 +++++--- tests/fixtures/models.py | 9 ++- tests/models/hf/test_fsdp_weight_tying.py | 10 ++-- tests/models/hf/test_hf_config.py | 29 +++++++--- tests/models/hf/test_hf_peft_wrapping.py | 10 ++-- tests/models/hf/test_hf_v_mpt.py | 28 ++++++--- .../inference_api_wrapper/test_fmapi.py | 8 +-- .../test_inference_api_eval_wrapper.py | 6 +- tests/models/layers/test_huggingface_flash.py | 9 ++- tests/models/test_model.py | 48 +++++++++++---- tests/test_registry.py | 1 + 27 files changed, 307 insertions(+), 177 deletions(-) delete mode 100644 llmfoundry/models/model_registry.py diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 0529db7265..ea9eba74ce 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -31,7 +31,6 @@ flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn -from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel) from llmfoundry.tokenizers import TiktokenTokenizerWrapper @@ -53,7 +52,6 @@ 'ComposerHFCausalLM', 'ComposerHFPrefixLM', 'ComposerHFT5', - 'COMPOSER_MODEL_REGISTRY', 'scaled_multihead_dot_product_attention', 'flash_attn_fn', 'triton_flash_attn_fn', diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 9e0de424e5..37faa14fdd 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -16,12 +16,12 @@ from torch.utils.data import DataLoader from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.utils.warnings import experimental +from llmfoundry.utils.warnings import experimental_class log = logging.getLogger(__name__) -@experimental('CurriculumLearning callback') +@experimental_class('CurriculumLearning callback') class CurriculumLearning(CallbackWithConfig): """Starts an epoch with a different dataset when resuming from a checkpoint. 
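The call-site pattern this patch migrates is the same everywhere: the removed COMPOSER_MODEL_REGISTRY dictionary lookup is replaced by the registry-backed build_composer_model helper. A minimal sketch of the new pattern follows, using the built-in 'mpt_causal_lm' name registered later in this patch; the tiny config values and the 'EleutherAI/gpt-neox-20b' tokenizer are illustrative choices, not taken from this diff.

from omegaconf import OmegaConf as om

from llmfoundry.utils.builders import build_composer_model, build_tokenizer

# Old pattern (removed in this PR):
#   model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer)
# New pattern: look the name up in llmfoundry.registry.models via the builder.
model_cfg = om.create({
    'name': 'mpt_causal_lm',  # any name registered in llmfoundry.registry.models
    'd_model': 128,
    'n_heads': 2,
    'n_layers': 2,
    'expansion_ratio': 4,
    'max_seq_len': 256,
    'vocab_size': 50368,
})
tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {'model_max_length': 256})
model = build_composer_model(
    name=model_cfg.name,
    cfg=model_cfg,
    tokenizer=tokenizer,
)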
diff --git a/llmfoundry/models/__init__.py b/llmfoundry/models/__init__.py index 392d7d3c3c..36234d3c14 100644 --- a/llmfoundry/models/__init__.py +++ b/llmfoundry/models/__init__.py @@ -3,8 +3,22 @@ from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, ComposerHFT5) +from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper, + FMAPIChatAPIEvalWrapper, + OpenAICausalLMEvalWrapper, + OpenAIChatAPIEvalWrapper) from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel) +from llmfoundry.registry import models + +models.register('mpt_causal_lm', func=ComposerMPTCausalLM) +models.register('hf_causal_lm', func=ComposerHFCausalLM) +models.register('hf_prefix_lm', func=ComposerHFPrefixLM) +models.register('hf_t5', func=ComposerHFT5) +models.register('openai_causal_lm', func=OpenAICausalLMEvalWrapper) +models.register('fmapi_causal_lm', func=FMAPICasualLMEvalWrapper) +models.register('openai_chat', func=OpenAIChatAPIEvalWrapper) +models.register('fmapi_chat', func=FMAPIChatAPIEvalWrapper) __all__ = [ 'ComposerHFCausalLM', @@ -15,4 +29,8 @@ 'MPTModel', 'MPTForCausalLM', 'ComposerMPTCausalLM', + 'OpenAICausalLMEvalWrapper', + 'FMAPICasualLMEvalWrapper', + 'OpenAIChatAPIEvalWrapper', + 'FMAPIChatAPIEvalWrapper', ] diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py index d059f54f85..5956d49cc8 100644 --- a/llmfoundry/models/hf/hf_t5.py +++ b/llmfoundry/models/hf/hf_t5.py @@ -17,12 +17,12 @@ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.utils import (adapt_tokenizer_for_denoising, init_empty_weights) -from llmfoundry.utils.warnings import experimental +from llmfoundry.utils.warnings import experimental_class __all__ = ['ComposerHFT5'] -@experimental('ComposerHFT5') +@experimental_class('ComposerHFT5') class ComposerHFT5(HuggingFaceModelWithZLoss): """Configures a :class:`.HuggingFaceModel` around a T5. 
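The register calls above cover the built-in models. Because the registry is open, external code can add its own entries as long as the constructor follows the contract described for the models registry later in this patch: an omegaconf DictConfig named om_model_config plus a tokenizer. The sketch below is hypothetical; TinyComposerLM and the 'tiny_lm' name are invented for illustration and are not part of this PR.

from typing import Any, Optional

import torch
from composer.models import ComposerModel
from omegaconf import DictConfig

from llmfoundry.registry import models
from llmfoundry.utils.builders import build_composer_model


class TinyComposerLM(ComposerModel):
    # Hypothetical toy model; forward and loss are the two methods ComposerModel requires.

    def __init__(self, om_model_config: DictConfig, tokenizer: Optional[Any] = None):
        super().__init__()
        self.tokenizer = tokenizer
        self.lm_head = torch.nn.Linear(om_model_config.d_model, om_model_config.vocab_size)

    def forward(self, batch: dict) -> torch.Tensor:
        return self.lm_head(batch['hidden_states'])

    def loss(self, outputs: torch.Tensor, batch: dict) -> torch.Tensor:
        return torch.nn.functional.cross_entropy(
            outputs.view(-1, outputs.size(-1)), batch['labels'].view(-1))


# Register the custom model, then build it the same way train.py and eval.py now do.
models.register('tiny_lm', func=TinyComposerLM)

model = build_composer_model(
    name='tiny_lm',
    cfg=DictConfig({'d_model': 8, 'vocab_size': 16}),
    tokenizer=None,  # this toy model does not need a tokenizer
)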
diff --git a/llmfoundry/models/inference_api_wrapper/fmapi.py b/llmfoundry/models/inference_api_wrapper/fmapi.py index e74d2c3849..58ea302ace 100644 --- a/llmfoundry/models/inference_api_wrapper/fmapi.py +++ b/llmfoundry/models/inference_api_wrapper/fmapi.py @@ -4,9 +4,9 @@ import logging import os import time -from typing import Dict import requests +from omegaconf import DictConfig from transformers import AutoTokenizer from llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( @@ -25,7 +25,7 @@ class FMAPIEvalInterface(OpenAIEvalInterface): def block_until_ready(self, base_url: str): """Block until the endpoint is ready.""" sleep_s = 5 - timout_s = 5 * 60 # At max, wait 5 minutes + timeout_s = 5 * 60 # At max, wait 5 minutes ping_url = f'{base_url}/ping' @@ -42,25 +42,25 @@ def block_until_ready(self, base_url: str): time.sleep(sleep_s) waited_s += sleep_s - if waited_s >= timout_s: + if waited_s >= timeout_s: raise TimeoutError( f'Endpoint {ping_url} did not become read after {waited_s:,} seconds, exiting' ) - def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer): - is_local = model_cfg.pop('local', False) + def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer): + is_local = om_model_config.pop('local', False) if is_local: base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT', 'http://0.0.0.0:8080/v2') - model_cfg['base_url'] = base_url + om_model_config['base_url'] = base_url self.block_until_ready(base_url) - if 'base_url' not in model_cfg: + if 'base_url' not in om_model_config: raise ValueError( 'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper' ) - super().__init__(model_cfg, tokenizer) + super().__init__(om_model_config, tokenizer) class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper): diff --git a/llmfoundry/models/inference_api_wrapper/interface.py b/llmfoundry/models/inference_api_wrapper/interface.py index 3ee57d2f46..4c30e7822d 100644 --- a/llmfoundry/models/inference_api_wrapper/interface.py +++ b/llmfoundry/models/inference_api_wrapper/interface.py @@ -1,22 +1,24 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from composer.core.types import Batch from composer.metrics import InContextLearningMetric from composer.models import ComposerModel +from omegaconf import DictConfig from torchmetrics import Metric from transformers import AutoTokenizer from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS -from llmfoundry.utils.builders import build_metric class InferenceAPIEvalWrapper(ComposerModel): - def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer): + def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer): + from llmfoundry.utils.builders import build_metric + self.tokenizer = tokenizer self.labels = None eval_metrics = [ diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py index c260a06aff..bacf71b8e2 100644 --- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -12,6 +12,7 @@ import torch from composer.core.types import Batch from composer.utils.import_helpers import MissingConditionalImportError +from omegaconf import DictConfig from transformers import AutoTokenizer log = logging.getLogger(__name__) @@ -34,8 +35,9 @@ class 
OpenAIEvalInterface(InferenceAPIEvalWrapper): - def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: - super().__init__(model_cfg, tokenizer) + def __init__(self, om_model_config: DictConfig, + tokenizer: AutoTokenizer) -> None: + super().__init__(om_model_config, tokenizer) try: import openai except ImportError as e: @@ -45,7 +47,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: conda_channel='conda-forge') from e api_key = os.environ.get('OPENAI_API_KEY') - base_url = model_cfg.get('base_url') + base_url = om_model_config.get('base_url') if base_url is None: # Using OpenAI default, where the API key is required if api_key is None: @@ -61,10 +63,10 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: api_key = 'placeholder' # This cannot be None self.client = openai.OpenAI(base_url=base_url, api_key=api_key) - if 'version' in model_cfg: - self.model_name = model_cfg['version'] + if 'version' in om_model_config: + self.model_name = om_model_config['version'] else: - self.model_name = model_cfg['name'] + self.model_name = om_model_config['name'] def generate_completion(self, prompt: str, num_tokens: int): raise NotImplementedError() @@ -109,8 +111,9 @@ def try_generate_completion(self, prompt: str, num_tokens: int): class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface): - def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: - super().__init__(model_cfg, tokenizer) + def __init__(self, om_model_config: DictConfig, + tokenizer: AutoTokenizer) -> None: + super().__init__(om_model_config, tokenizer) self.generate_completion = lambda prompt, num_tokens: self.client.chat.completions.create( model=self.model_name, @@ -118,8 +121,8 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: 'role': 'system', 'content': - model_cfg.get('system_role_prompt', - 'Please complete the following text: ') + om_model_config.get('system_role_prompt', + 'Please complete the following text: ') }, { 'role': 'user', 'content': prompt @@ -244,8 +247,9 @@ def process_result(self, completion: Optional['ChatCompletion']): class OpenAICausalLMEvalWrapper(OpenAIEvalInterface): - def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: - super().__init__(model_cfg, tokenizer) + def __init__(self, om_model_config: DictConfig, + tokenizer: AutoTokenizer) -> None: + super().__init__(om_model_config, tokenizer) self.generate_completion = lambda prompt, num_tokens: self.client.completions.create( model=self.model_name, prompt=prompt, diff --git a/llmfoundry/models/model_registry.py b/llmfoundry/models/model_registry.py deleted file mode 100644 index ff9942f5f6..0000000000 --- a/llmfoundry/models/model_registry.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM, - ComposerHFT5) -from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper, - FMAPIChatAPIEvalWrapper, - OpenAICausalLMEvalWrapper, - OpenAIChatAPIEvalWrapper) -from llmfoundry.models.mpt import ComposerMPTCausalLM - -COMPOSER_MODEL_REGISTRY = { - 'mpt_causal_lm': ComposerMPTCausalLM, - 'hf_causal_lm': ComposerHFCausalLM, - 'hf_prefix_lm': ComposerHFPrefixLM, - 'hf_t5': ComposerHFT5, - 'openai_causal_lm': OpenAICausalLMEvalWrapper, - 'fmapi_causal_lm': FMAPICasualLMEvalWrapper, - 'openai_chat': OpenAIChatAPIEvalWrapper, - 'fmapi_chat': FMAPIChatAPIEvalWrapper, -} diff --git 
a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index 3aefbaa875..655093d138 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -11,7 +11,7 @@ from composer.optim import ComposerScheduler, LinearScheduler from composer.optim.scheduler import _convert_time -from llmfoundry.utils.warnings import experimental +from llmfoundry.utils.warnings import experimental_class __all__ = ['InverseSquareRootWithWarmupScheduler'] @@ -34,7 +34,7 @@ def _raise_if_units_dur(time: Union[str, Time], name: str) -> None: raise ValueError(f'{name} cannot be in units of "dur".') -@experimental('InverseSquareRootWithWarmupScheduler') +@experimental_class('InverseSquareRootWithWarmupScheduler') class InverseSquareRootWithWarmupScheduler(ComposerScheduler): r"""Inverse square root LR decay with warmup and optional linear cooldown. diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 897f714d62..e289a923b6 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -4,6 +4,7 @@ from composer.core import Algorithm, Callback, DataSpec from composer.loggers import LoggerDestination +from composer.models import ComposerModel from composer.optim import ComposerScheduler from omegaconf import DictConfig from torch.optim import Optimizer @@ -83,6 +84,15 @@ entry_points=True, description=_schedulers_description) +_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model +constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. +Note: This will soon be updated to take in named kwargs instead of a config directly.""" +models = create_registry('llmfoundry', + 'models', + generic_type=Type[ComposerModel], + entry_points=True, + description=_models_description) + _dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. 
The function should take a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.""" dataloaders = create_registry( @@ -106,5 +116,7 @@ 'optimizers', 'algorithms', 'schedulers', + 'models', 'metrics', + 'dataloaders', ] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index b064b00759..fe803d62db 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -1,18 +1,21 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import functools import logging import os import re from collections import OrderedDict -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import (Any, ContextManager, Dict, Iterable, List, Optional, Tuple, + Union) import torch from composer.core import Algorithm, Callback, Evaluator from composer.datasets.in_context_learning_evaluation import \ get_icl_task_dataloader from composer.loggers import LoggerDestination +from composer.models import ComposerModel from composer.optim.scheduler import ComposerScheduler from composer.utils import dist from omegaconf import DictConfig, ListConfig @@ -39,6 +42,7 @@ 'build_optimizer', 'build_scheduler', 'build_tokenizer', + 'build_composer_model', 'build_metric', ] @@ -155,6 +159,58 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb +def build_composer_model( + name: str, + cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, + init_context: Optional[ContextManager] = None, + master_weights_dtype: Optional[str] = None, +) -> ComposerModel: + """Builds a ComposerModel from the registry. + + Args: + name (str): Name of the model to build. + cfg (DictConfig): Configuration for the model. + tokenizer (PreTrainedTokenizerBase): Tokenizer to use. + init_context (Optional[ContextManager], optional): Context manager to use for initialization. Defaults to None. + master_weights_dtype (Optional[str], optional): Master weights dtype. Defaults to None. + + Returns: + ComposerModel: _description_ + """ + if init_context is None: + init_context = contextlib.nullcontext() + + with init_context: + model = construct_from_registry( + name=name, + registry=registry.models, + pre_validation_function=ComposerModel, + post_validation_function=None, + kwargs={ + 'om_model_config': cfg, + 'tokenizer': tokenizer + }, + ) + + str_dtype_to_torch_dtype = { + 'f16': torch.float16, + 'float16': torch.float16, + 'bf16': torch.bfloat16, + 'bfloat16': torch.bfloat16, + } + + if master_weights_dtype is not None: + if master_weights_dtype not in str_dtype_to_torch_dtype: + raise ValueError( + f'Invalid master_weights_dtype: {master_weights_dtype}. 
' + + f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.') + dtype = str_dtype_to_torch_dtype[master_weights_dtype] + model = model.to(dtype=dtype) + + return model + + def build_callback( name: str, kwargs: Optional[Dict[str, Any]] = None, diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py index 892d8d5a11..6c9106b2e7 100644 --- a/llmfoundry/utils/warnings.py +++ b/llmfoundry/utils/warnings.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools import warnings -from typing import Any, Callable, TypeVar, cast +from typing import Any, Callable, Type, TypeVar, cast __all__ = [ 'VersionedDeprecationWarning', @@ -51,7 +51,7 @@ def __init__(self, feature_name: str) -> None: # Decorator version of experimental warning -def experimental(feature_name: str) -> Callable[[F], F]: +def experimental_function(feature_name: str) -> Callable[[F], F]: """Decorator to mark a function as experimental. The message displayed will be {feature_name} is experimental and may change with future versions. @@ -73,3 +73,19 @@ def wrapper(*args: Any, **kwargs: Any): return cast(F, wrapper) return decorator + + +def experimental_class(feature_name: str) -> Callable[[Type], Type]: + """Class decorator to mark a class as experimental.""" + + def class_decorator(cls: Type): + original_init = cls.__init__ + + def new_init(self: Any, *args: Any, **kwargs: Any): + warnings.warn(ExperimentalWarning(feature_name)) + original_init(self, *args, **kwargs) + + cls.__init__ = new_init + return cls + + return class_decorator diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 24394ad906..961b50e254 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -12,22 +12,19 @@ import pandas as pd import torch from composer.loggers.logger_destination import LoggerDestination -from composer.models.base import ComposerModel from composer.trainer import Trainer from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from rich.traceback import install -from transformers import PreTrainedTokenizerBase from llmfoundry.utils import (find_mosaicml_logger, log_eval_analytics, maybe_create_mosaicml_logger) install() -from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, - build_evaluators, build_logger, - build_tokenizer) + build_composer_model, build_evaluators, + build_logger, build_tokenizer) from llmfoundry.utils.config_utils import (log_config, pop_config, process_init_device) from llmfoundry.utils.registry_utils import import_file @@ -35,30 +32,6 @@ log = logging.getLogger(__name__) -def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - fsdp_config: Optional[Dict], num_retries: int) -> ComposerModel: - init_context = process_init_device(model_cfg, fsdp_config) - - retries = 0 - composer_model = None - with init_context: - while retries < num_retries and composer_model is None: - try: - composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name]( - model_cfg, tokenizer) - except Exception as e: - retries += 1 - if retries >= num_retries: - raise e - else: - log.info( - f'Got exception {str(e)} while loading model {model_cfg.name}. 
{num_retries-retries} retries remaining' - ) - - assert composer_model is not None - return composer_model - - def evaluate_model( model_cfg: DictConfig, dist_timeout: Union[float, int], @@ -70,7 +43,6 @@ def evaluate_model( eval_gauntlet_config: Optional[Union[str, DictConfig]], eval_loader_config: Optional[Union[DictConfig, ListConfig]], fsdp_config: Optional[Dict], - num_retries: int, loggers: List[LoggerDestination], python_log_level: Optional[str], precision: str, @@ -118,8 +90,14 @@ def evaluate_model( 'The FSDP config block is not supported when loading ' + 'Hugging Face models in 8bit.') - composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, - num_retries) + init_context = process_init_device(model_cfg.model, fsdp_config) + + composer_model = build_composer_model( + name=model_cfg.model.name, + cfg=model_cfg.model, + tokenizer=tokenizer, + init_context=init_context, + ) # Now add the eval metrics if eval_loader_config is not None: @@ -236,10 +214,6 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: 'run_name', must_exist=False, default_value=default_run_name) - num_retries: int = pop_config(cfg, - 'num_retries', - must_exist=False, - default_value=3) loggers_cfg: Dict[str, Any] = pop_config(cfg, 'loggers', must_exist=False, @@ -318,7 +292,6 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: eval_gauntlet_config=eval_gauntlet_config, eval_loader_config=eval_loader_config, fsdp_config=fsdp_config, - num_retries=num_retries, loggers=loggers, python_log_level=python_log_level, precision=precision, diff --git a/scripts/inference/benchmarking/benchmark.py b/scripts/inference/benchmarking/benchmark.py index d2e51bb7a5..00daf6b559 100644 --- a/scripts/inference/benchmarking/benchmark.py +++ b/scripts/inference/benchmarking/benchmark.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -from llmfoundry import COMPOSER_MODEL_REGISTRY +from llmfoundry.utils.builders import build_composer_model, build_tokenizer def get_dtype(dtype: str): @@ -58,8 +58,17 @@ def main(config: DictConfig): }, } - composer_model = COMPOSER_MODEL_REGISTRY[config.model.name]( - config.model, config.tokenizer) + tokenizer_name = config.tokenizer['name'] + tokenizer_kwargs = config.tokenizer.get('kwargs', {}) + tokenizer = build_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_kwargs=tokenizer_kwargs, + ) + composer_model = build_composer_model( + name=config.model.name, + cfg=config.model, + tokenizer=tokenizer, + ) model = composer_model.model model.eval() diff --git a/scripts/train/train.py b/scripts/train/train.py index e1d91c3cdb..93491452dd 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -25,16 +25,13 @@ install() -from transformers import PreTrainedTokenizerBase - -from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_algorithm, build_callback, - build_evaluators, build_logger, - build_optimizer, build_scheduler, - build_tokenizer) + build_composer_model, build_evaluators, + build_logger, build_optimizer, + build_scheduler, build_tokenizer) from llmfoundry.utils.config_utils import (log_config, pop_config, process_init_device, update_batch_size_info) @@ -139,17 +136,6 @@ def validate_config(cfg: DictConfig): ) -def build_composer_model(model_cfg: DictConfig, - tokenizer: PreTrainedTokenizerBase): - warnings.filterwarnings( - 
action='ignore', - message='Torchmetrics v0.9 introduced a new argument class property') - if model_cfg.name not in COMPOSER_MODEL_REGISTRY: - raise ValueError( - f'Not sure how to build model with name={model_cfg.name}') - return COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer) - - def main(cfg: DictConfig) -> Trainer: # Run user provided code if specified code_paths = pop_config(cfg, @@ -548,13 +534,13 @@ def main(cfg: DictConfig) -> Trainer: # Build Model log.info('Initializing model...') - with init_context: - model = build_composer_model(model_config, tokenizer) - - if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'): - model = model.to(dtype=torch.bfloat16) - elif model_config.get('master_weights_dtype') in ('f16', 'float16'): - model = model.to(dtype=torch.float16) + model = build_composer_model( + name=model_config.name, + cfg=model_config, + tokenizer=tokenizer, + init_context=init_context, + master_weights_dtype=model_config.get('master_weights_dtype', None), + ) # Log number of parameters n_params = sum(p.numel() for p in model.parameters()) diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index c9dfb88732..63c4ea8261 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -11,8 +11,8 @@ from composer import Trainer from composer.loggers import InMemoryLogger -from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.utils import build_tokenizer +from llmfoundry.utils.builders import build_composer_model from scripts.eval.eval import main # noqa: E402 from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, gpt_tiny_cfg) @@ -47,8 +47,12 @@ def mock_saved_model_path(eval_cfg: Union[om.ListConfig, om.DictConfig]): tokenizer = build_tokenizer(model_cfg.tokenizer.name, model_cfg.tokenizer.get('kwargs', {})) # build model - model = COMPOSER_MODEL_REGISTRY[model_cfg.model.name](model_cfg.model, - tokenizer) + model = build_composer_model( + name=model_cfg.model.name, + cfg=model_cfg.model, + tokenizer=tokenizer, + ) + # create mocked save checkpoint trainer = Trainer(model=model, device=device) saved_model_path = os.path.join(os.getcwd(), 'test_model.pt') diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index f1cd94ea3c..0119f8edd2 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -21,12 +21,12 @@ from torch.utils.data import DataLoader from transformers import PreTrainedModel, PreTrainedTokenizerBase -from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename from llmfoundry.data.finetuning import build_finetuning_dataloader from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM -from llmfoundry.utils.builders import build_optimizer, build_tokenizer +from llmfoundry.utils.builders import (build_composer_model, build_optimizer, + build_tokenizer) from scripts.inference.convert_composer_to_hf import convert_composer_to_hf from tests.data_utils import make_tiny_ft_dataset @@ -543,8 +543,11 @@ def test_huggingface_conversion_callback( device_batch_size, ) - original_model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, - tokenizer) + original_model = build_composer_model( + name=model_cfg['name'], + cfg=model_cfg, + tokenizer=tokenizer, + ) optimizer_config = { 
'name': 'decoupled_adamw', @@ -742,8 +745,11 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool, om_cfg['model']['init_device'] = 'cpu' tokenizer = transformers.AutoTokenizer.from_pretrained( om_cfg.tokenizer.name, use_auth_token=model == 'llama2') - original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name]( - om_cfg['model'], tokenizer) + original_model = build_composer_model( + name=om_cfg['model'].name, + cfg=om_cfg['model'], + tokenizer=tokenizer, + ) trainer = Trainer(model=original_model, device='cpu') trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt')) @@ -842,8 +848,11 @@ def test_convert_and_generate_meta(tie_word_embeddings: str, om_cfg['tie_word_embeddings'] = tie_word_embeddings tokenizer = transformers.AutoTokenizer.from_pretrained( om_cfg.tokenizer.name) - original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name]( - om_cfg['model'], tokenizer) + original_model = build_composer_model( + name=om_cfg['model'].name, + cfg=om_cfg['model'], + tokenizer=tokenizer, + ) trainer = Trainer(model=original_model, device='cpu') trainer.save_checkpoint(os.path.join(tmp_path_gathered, 'checkpoint.pt')) diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 1b1ef86302..e4e6892fe3 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -8,13 +8,16 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM -from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils.builders import build_composer_model, build_tokenizer def _build_model(config: DictConfig, tokenizer: PreTrainedTokenizerBase): - model = COMPOSER_MODEL_REGISTRY[config.name](config, tokenizer) + model = build_composer_model( + name=config.name, + cfg=config, + tokenizer=tokenizer, + ) return model diff --git a/tests/models/hf/test_fsdp_weight_tying.py b/tests/models/hf/test_fsdp_weight_tying.py index 00172363d1..6e7838e7ba 100644 --- a/tests/models/hf/test_fsdp_weight_tying.py +++ b/tests/models/hf/test_fsdp_weight_tying.py @@ -9,8 +9,7 @@ from composer.models.huggingface import maybe_get_underlying_model from omegaconf import OmegaConf as om -from llmfoundry import COMPOSER_MODEL_REGISTRY -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils.builders import build_composer_model, build_tokenizer @pytest.mark.world_size(2) @@ -68,8 +67,11 @@ def test_fsdp_weight_tying(peft_config: Optional[dict], tmp_path: pathlib.Path, tokenizer_kwargs={'model_max_length': 32}, ) - original_model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, - tokenizer) + original_model = build_composer_model( + name=model_cfg['name'], + cfg=model_cfg, + tokenizer=tokenizer, + ) underlying_model = maybe_get_underlying_model(original_model.model) lm_head = underlying_model.lm_head if peft_config is None else underlying_model.lm_head diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py index d007850b68..d541f0a30c 100644 --- a/tests/models/hf/test_hf_config.py +++ b/tests/models/hf/test_hf_config.py @@ -13,9 +13,9 @@ from omegaconf import OmegaConf as om from transformers import AutoModelForCausalLM -from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer +from llmfoundry.utils.builders import build_composer_model def 
test_remote_code_false_mpt( @@ -45,8 +45,11 @@ def test_remote_code_false_mpt( with pytest.raises( ValueError, match='trust_remote_code must be set to True for MPT models.'): - _ = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, - tokenizer) + _ = build_composer_model( + name=test_cfg.model.name, + cfg=test_cfg.model, + tokenizer=tokenizer, + ) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) @@ -132,8 +135,11 @@ def test_hf_config_override( tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, - tokenizer) + model = build_composer_model( + name=test_cfg.model.name, + cfg=test_cfg.model, + tokenizer=tokenizer, + ) # save model tmp_dir = tempfile.TemporaryDirectory() @@ -153,8 +159,11 @@ def test_hf_config_override( }) hf_model_config.model = model_cfg - hf_model = COMPOSER_MODEL_REGISTRY[hf_model_config.model.name]( - hf_model_config.model, tokenizer=tokenizer) + hf_model = build_composer_model( + name=hf_model_config.model.name, + cfg=hf_model_config.model, + tokenizer=tokenizer, + ) for k, v in hf_model_config.model.config_overrides.items(): if isinstance(v, Mapping): @@ -185,7 +194,11 @@ def test_rope_scaling_override(): } model_cfg = om.create(model_cfg) - model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer=None) + model = build_composer_model( + name=model_cfg.name, + cfg=model_cfg, + tokenizer=None, # type: ignore + ) # This would error if the config isn't parsed into a proper dictionary model.get_metadata() assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5} diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py index 11d183c9d6..d8bea33dd4 100644 --- a/tests/models/hf/test_hf_peft_wrapping.py +++ b/tests/models/hf/test_hf_peft_wrapping.py @@ -12,9 +12,8 @@ from omegaconf import OmegaConf as om from peft import LoraConfig, get_peft_model -from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils.builders import build_composer_model, build_tokenizer def test_peft_wraps(): @@ -84,8 +83,11 @@ def test_lora_mixed_init(peft_config: Optional[dict], tmp_path: pathlib.Path, tokenizer_kwargs={'model_max_length': 32}, ) - original_model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, - tokenizer) + original_model = build_composer_model( + name=model_cfg['name'], + cfg=model_cfg, + tokenizer=tokenizer, + ) trainer = Trainer( model=original_model, diff --git a/tests/models/hf/test_hf_v_mpt.py b/tests/models/hf/test_hf_v_mpt.py index 1319934506..b44c8d14c2 100644 --- a/tests/models/hf/test_hf_v_mpt.py +++ b/tests/models/hf/test_hf_v_mpt.py @@ -8,7 +8,7 @@ from composer.utils import reproducibility from omegaconf import OmegaConf as om -from llmfoundry import COMPOSER_MODEL_REGISTRY +from llmfoundry.utils.builders import build_composer_model, build_tokenizer @pytest.mark.gpu @@ -35,7 +35,7 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, conf_path = 'scripts/train/yamls/pretrain/mpt-125m.yaml' # set cfg path batch_size = 2 # set batch size - device = 'cuda' # set decive + device = 'cuda' # set device # get hf gpt2 cfg hf_cfg = om.create({ @@ -57,8 +57,17 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, # get hf gpt2 model print(hf_cfg) - 
hf_model = COMPOSER_MODEL_REGISTRY[hf_cfg.model.name]( - hf_cfg.model, hf_cfg.tokenizer).to(device) + tokenizer_name = hf_cfg.tokenizer['name'] + tokenizer_kwargs = hf_cfg.tokenizer.get('kwargs', {}) + tokenizer = build_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_kwargs=tokenizer_kwargs, + ) + hf_model = build_composer_model( + name=hf_cfg.model.name, + cfg=hf_cfg.model, + tokenizer=tokenizer, + ).to(device) hf_n_params = sum(p.numel() for p in hf_model.parameters()) hf_model.model.config.embd_pdrop = dropout @@ -72,7 +81,7 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, # in mosaic gpt, attn_dropout is integrated into the FlashMHA kernel # and will therefore generate different drop idx when compared to nn.Dropout - # reguradless of if rng is seeded + # regardless of if rng is seeded # attn_dropout must be set to 0 for numerical comparisons. hf_model.model.config.attn_pdrop = 0.0 for b in hf_model.model.transformer.h: @@ -96,7 +105,7 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, model_cfg.emb_pdrop = hf_model.model.config.embd_pdrop # attn_dropout is integrated into the FlashMHA kernel # given this, it will generate different drop idx when compared to nn.Dropout - # reguradless of if rng is seeded. + # regardless of if rng is seeded. model_cfg.attn_pdrop = hf_model.model.config.attn_pdrop model_cfg.n_layers = hf_model.model.config.n_layer model_cfg.d_model = hf_model.model.config.n_embd @@ -106,8 +115,11 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, print('Initializing model...') print(model_cfg) - model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, - cfg.tokenizer).to(device) + model = build_composer_model( + name=model_cfg.name, + cfg=model_cfg, + tokenizer=tokenizer, + ).to(device) n_params = sum(p.numel() for p in model.parameters()) if alibi: diff --git a/tests/models/inference_api_wrapper/test_fmapi.py b/tests/models/inference_api_wrapper/test_fmapi.py index 32b7764400..bde2c90d36 100644 --- a/tests/models/inference_api_wrapper/test_fmapi.py +++ b/tests/models/inference_api_wrapper/test_fmapi.py @@ -95,10 +95,10 @@ def test_casual_fmapi_wrapper(tmp_path: str): tokenizer = transformers.AutoTokenizer.from_pretrained( 'mosaicml/mpt-7b-8k-instruct') - model = FMAPICasualLMEvalWrapper(model_cfg={ + model = FMAPICasualLMEvalWrapper(om_model_config=DictConfig({ 'local': True, 'name': 'mosaicml/mpt-7b-8k-instruct' - }, + }), tokenizer=tokenizer) with patch.object(model, 'client') as mock: mock.completions.create = mock_create @@ -129,10 +129,10 @@ def test_chat_fmapi_wrapper(tmp_path: str): tokenizer = transformers.AutoTokenizer.from_pretrained( 'mosaicml/mpt-7b-8k-instruct') - chatmodel = FMAPIChatAPIEvalWrapper(model_cfg={ + chatmodel = FMAPIChatAPIEvalWrapper(om_model_config=DictConfig({ 'local': True, 'name': 'mosaicml/mpt-7b-8k-instruct' - }, + }), tokenizer=tokenizer) with patch.object(chatmodel, 'client') as mock: diff --git a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py index a125203e19..7ecb61aa43 100644 --- a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py +++ b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py @@ -99,7 +99,8 @@ def test_openai_api_eval_wrapper(tmp_path: str, openai_api_key_env_var: str): model_name = 'davinci' tokenizer = TiktokenTokenizerWrapper(model_name=model_name, pad_token='<|endoftext|>') - model = 
OpenAICausalLMEvalWrapper(model_cfg={'version': model_name}, + model = OpenAICausalLMEvalWrapper(om_model_config=DictConfig( + {'version': model_name}), tokenizer=tokenizer) with patch.object(model, 'client') as mock: mock.completions.create = mock_create @@ -129,7 +130,8 @@ def test_chat_api_eval_wrapper(tmp_path: str, openai_api_key_env_var: str): model_name = 'gpt-3.5-turbo' tokenizer = TiktokenTokenizerWrapper(model_name=model_name, pad_token='<|endoftext|>') - chatmodel = OpenAIChatAPIEvalWrapper(model_cfg={'version': model_name}, + chatmodel = OpenAIChatAPIEvalWrapper(om_model_config=DictConfig( + {'version': model_name}), tokenizer=tokenizer) with patch.object(chatmodel, 'client') as mock: mock.chat.completions.create.return_value = MockChatCompletion( diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py index 818136e8fa..dfd3b17f96 100644 --- a/tests/models/layers/test_huggingface_flash.py +++ b/tests/models/layers/test_huggingface_flash.py @@ -13,12 +13,11 @@ from omegaconf import OmegaConf as om from transformers.models.llama.modeling_llama import LlamaAttention -from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.models.hf.hf_fsdp import rgetattr from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.layers.llama_attention_monkeypatch import ( llama_attention_patch_torch, llama_attention_patch_triton) -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils.builders import build_composer_model, build_tokenizer @pytest.mark.parametrize('patch_fn_name', ['torch', 'triton']) @@ -172,7 +171,11 @@ def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str): ) and use_flash_attention_2 else contextlib.nullcontext() with error_context: - model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer) + model = build_composer_model( + name=model_cfg['name'], + cfg=model_cfg, + tokenizer=tokenizer, + ) # check that it actually used flash attention 2 assert model.model.config._attn_implementation == ( diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 4765c4003b..79a5e4f98f 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -25,7 +25,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.bloom.modeling_bloom import build_alibi_tensor -from llmfoundry import COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM +from llmfoundry import ComposerHFCausalLM from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias from llmfoundry.models.layers.attention import (check_alibi_support, @@ -33,6 +33,7 @@ from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer +from llmfoundry.utils.builders import build_composer_model def get_config( @@ -83,8 +84,12 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) - model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, - tokenizer) + model = build_composer_model( + name=test_cfg.model.name, + cfg=test_cfg.model, + tokenizer=tokenizer, + ) + # Optimizer assert test_cfg.optimizer.name == 'decoupled_adamw' optimizer = DecoupledAdamW(model.parameters(), @@ -159,6 +164,7 @@ def test_full_forward_and_backward(batch_size: int = 
2): original_params = next(model.parameters()).clone().data outputs = model(batch) loss = model.loss(outputs, batch) + assert isinstance(loss, torch.Tensor) loss.backward() optimizer.step() updated_params = next(model.parameters()).clone().data @@ -175,6 +181,7 @@ def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2): original_params = next(model.parameters()).clone().data outputs = model(batch) loss = model.loss(outputs, batch) + assert isinstance(loss, torch.Tensor) loss.backward() optimizer.step() updated_params = next(model.parameters()).clone().data @@ -273,8 +280,11 @@ def test_full_forward_and_backward_gpt2_small(prefixlm: bool, tokenizer = build_tokenizer(neo_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) - model = COMPOSER_MODEL_REGISTRY[neo_cfg.model.name](neo_cfg.model, - tokenizer).to(device) + model = build_composer_model( + name=neo_cfg.model.name, + cfg=neo_cfg.model, + tokenizer=tokenizer, + ).to(device) assert isinstance(model.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) @@ -296,6 +306,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm: bool, original_params = next(model.parameters()).clone().data outputs = model(batch) loss = model.loss(outputs, batch) + assert isinstance(loss, torch.Tensor) loss.backward() optimizer.step() updated_params = next(model.parameters()).clone().data @@ -318,8 +329,11 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): tokenizer = build_tokenizer(t5_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) - model = COMPOSER_MODEL_REGISTRY[t5_cfg.model.name](t5_cfg.model, - tokenizer).to(device) + model = build_composer_model( + name=t5_cfg.model.name, + cfg=t5_cfg.model, + tokenizer=tokenizer, + ).to(device) assert isinstance(model.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) @@ -340,6 +354,7 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): original_params = next(model.parameters()).clone().data outputs = model(batch) loss = model.loss(outputs, batch) + assert isinstance(loss, torch.Tensor) loss.backward() optimizer.step() updated_params = next(model.parameters()).clone().data @@ -391,8 +406,11 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str, tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) - model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, - tokenizer) + model_1 = build_composer_model( + name=test_cfg.model.name, + cfg=test_cfg.model, + tokenizer=tokenizer, + ) model_2 = copy.deepcopy(model_1) optimizer_1 = DecoupledAdamW(model_1.parameters(), @@ -416,6 +434,8 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str, loss_1 = model_1.loss(output_1, batch) loss_2 = model_2.loss(output_2, batch) + assert isinstance(loss_1, torch.Tensor) + assert isinstance(loss_2, torch.Tensor) assert loss_1 == loss_2 loss_1.backward() loss_2.backward() @@ -457,8 +477,11 @@ def test_loss_fn(): tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) - model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, - tokenizer) + model_1 = build_composer_model( + name=test_cfg.model.name, + cfg=test_cfg.model, + tokenizer=tokenizer, + ) model_2 = copy.deepcopy(model_1) model_1.to(test_cfg.device) @@ -487,6 +510,8 @@ def test_loss_fn(): loss_1 = model_1.loss(output_1, batch) loss_2 = model_2.loss(output_2, batch) + assert isinstance(loss_1, torch.Tensor) + assert isinstance(loss_2, torch.Tensor) assert loss_1.allclose(loss_2, 
rtol=1e-3, atol=1e-3), f'differed at step {i}' loss_1.backward() @@ -2023,6 +2048,7 @@ def test_hf_init(tmp_path: pathlib.Path, with torch.autocast('cuda', dtype=torch.bfloat16, enabled=True): outputs = model(batch) loss = model.loss(outputs, batch) + assert isinstance(loss, torch.Tensor) loss.backward() optimizer.step() diff --git a/tests/test_registry.py b/tests/test_registry.py index cf03cc18c9..30f6e0e38f 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -29,6 +29,7 @@ def test_expected_registries_exist(): 'callbacks_with_config', 'dataloaders', 'metrics', + 'models', } assert existing_registries == expected_registry_names
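One consequence of this patch that the test updates above illustrate: the inference API wrappers now take om_model_config as an omegaconf DictConfig instead of a plain dict, so direct callers have to wrap their config. A small sketch, assuming the openai and tiktoken extras are installed and that a placeholder OPENAI_API_KEY is acceptable, as in the updated tests:

import os

from omegaconf import DictConfig

from llmfoundry.models.inference_api_wrapper import OpenAICausalLMEvalWrapper
from llmfoundry.tokenizers import TiktokenTokenizerWrapper

# A real key is not needed just to construct the wrapper; the client is only
# exercised once completions are actually requested.
os.environ.setdefault('OPENAI_API_KEY', 'placeholder')

tokenizer = TiktokenTokenizerWrapper(model_name='davinci',
                                     pad_token='<|endoftext|>')

# Plain dicts such as {'version': 'davinci'} must now be wrapped in a DictConfig.
model = OpenAICausalLMEvalWrapper(
    om_model_config=DictConfig({'version': 'davinci'}),
    tokenizer=tokenizer,
)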