diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 492816ea07..aa3beda513 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib -import json +import copy import logging import os import tempfile @@ -10,14 +10,14 @@ from typing import Optional, Union import torch -from composer.callbacks.utils import create_interval_scheduler from composer.core import Callback, Event, State, Time from composer.core.state import fsdp_state_dict_type_context -from composer.loggers import Logger +from composer.loggers import Logger, MLFlowLogger from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader from composer.models import HuggingFaceModel from composer.utils import dist, format_name_with_dist_and_time, parse_uri -from transformers import PreTrainedTokenizerBase +from composer.utils.misc import create_interval_scheduler +from transformers import PreTrainedModel, PreTrainedTokenizerBase from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils.huggingface_hub_utils import \ @@ -39,6 +39,11 @@ class HuggingFaceCheckpointer(Callback): huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``. precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``. overwrite (bool): Whether to overwrite previous checkpoints. + mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not + be registered. Default is ``None``. + mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call. + Expected to contain ``metadata`` and ``task`` keys. 
If either is unspecified, ``metadata`` defaults to ``{'task': 'llm/v1/completions'}`` and + ``task`` defaults to ``'text-generation'``. """ def __init__( @@ -48,6 +53,8 @@ def __init__( huggingface_folder_name: str = 'ba{batch}', precision: str = 'float32', overwrite: bool = False, + mlflow_registered_model_name: Optional[str] = None, + mlflow_logging_config: Optional[dict] = None, ): self.backend, self.bucket_name, self.save_dir_format_str = parse_uri( save_folder) @@ -58,6 +65,22 @@ def __init__( 'float16': torch.float16, 'bfloat16': torch.bfloat16, }[precision] + + # mlflow config setup + self.mlflow_registered_model_name = mlflow_registered_model_name + if mlflow_logging_config is None: + mlflow_logging_config = {} + if self.mlflow_registered_model_name is not None: + # Both the metadata and the task are needed in order for mlflow + # and databricks optimized model serving to work + if 'metadata' not in mlflow_logging_config: + mlflow_logging_config['metadata'] = { + 'task': 'llm/v1/completions' + } + if 'task' not in mlflow_logging_config: + mlflow_logging_config['task'] = 'text-generation' + self.mlflow_logging_config = mlflow_logging_config + self.huggingface_folder_name_fstr = os.path.join( 'huggingface', huggingface_folder_name) self.check_interval = create_interval_scheduler( @@ -71,6 +94,7 @@ def __init__( self.remote_ud = None self.last_checkpoint_batch: Optional[Time] = None + self.mlflow_loggers = [] def run_event(self, event: Event, state: State, logger: Logger) -> None: # The interval scheduler handles only returning True for the appropriate events @@ -87,6 +111,23 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.remote_ud.init(state, logger) state.callbacks.append(self.remote_ud) + if self.mlflow_registered_model_name is not None: + self.mlflow_loggers = [ + logger_destination + for logger_destination in logger.destinations + if isinstance(logger_destination, MLFlowLogger) + ] + if len(self.mlflow_loggers) == 0: + raise ValueError( 
+ f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. ' + + + 'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.' + ) + + import mlflow + mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( + '5GB') + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -99,8 +140,6 @@ def _save_checkpoint(self, state: State, logger: Logger): MPTConfig.register_for_auto_class() MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM') - assert isinstance(state.model, HuggingFaceModel) - save_dir = format_name_with_dist_and_time( str( Path(self.save_dir_format_str) / @@ -114,9 +153,29 @@ def _save_checkpoint(self, state: State, logger: Logger): assert isinstance(temp_save_dir, str) # pyright doesn't know about enter_result - with fsdp_state_dict_type_context(state.model.model, - state_dict_type='full'): - state_dict = state.model.model.state_dict() + log.debug('Gathering state dict') + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + if state.is_model_ddp: + original_model: PreTrainedModel = state.model.module.model + state_dict_model = state.model.module.model + original_tokenizer = state.model.module.tokenizer + elif isinstance(state.model.model, FSDP): + original_model: PreTrainedModel = state.model.model.module + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + else: + original_model: PreTrainedModel = state.model.model + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + + state_dict_context = fsdp_state_dict_type_context( + original_model, state_dict_type='full') if ( + (not state.is_model_ddp) and isinstance( + state_dict_model, FSDP)) else contextlib.nullcontext() + + with state_dict_context: + state_dict = state_dict_model.state_dict() # convert the state dict to the requested precision for k, v in state_dict.items(): @@ -124,34 +183,35 @@ def 
_save_checkpoint(self, state: State, logger: Logger): state_dict[k] = v.to(dtype=self.dtype) if dist.get_global_rank() == 0: - # We raise above if the model is not a HuggingFaceModel, so this assert is safe - assert hasattr(state.model.model, 'save_pretrained') - state.model.model.save_pretrained(temp_save_dir, - state_dict=state_dict) - - if state.model.tokenizer is not None: - assert isinstance(state.model.tokenizer, + log.debug('Saving Hugging Face checkpoint to disk') + + copied_config = copy.deepcopy(original_model.config) + if copied_config.model_type == 'mpt': + copied_config.attn_config['attn_impl'] = 'torch' + copied_config.init_device = 'cpu' + + # TODO: after torch 2.1, we can load a state dict into a meta model + # and skip the extra model init + log.debug(f'Creating new model instance') + new_model_instance = type(original_model)(copied_config) + new_model_instance.to(dtype=self.dtype) + new_model_instance.load_state_dict(state_dict) + del state_dict + + log.debug('Saving Hugging Face checkpoint to disk') + new_model_instance.save_pretrained(temp_save_dir) + if original_tokenizer is not None: + assert isinstance(original_tokenizer, PreTrainedTokenizerBase) - state.model.tokenizer.save_pretrained(temp_save_dir) + original_tokenizer.save_pretrained(temp_save_dir) # Only need to edit files for MPT because it has custom code - if state.model.model.config.model_type == 'mpt': + if original_model.config.model_type == 'mpt': + log.debug('Editing MPT files for HuggingFace compatibility') edit_files_for_hf_compatibility(temp_save_dir) - with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f: - edited_config = json.load(f) - - if state.model.model.config.model_type == 'mpt': - edited_config['attn_config']['attn_impl'] = 'torch' - edited_config['init_device'] = 'cpu' - - edited_config['torch_dtype'] = self.precision - with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f: - json.dump(edited_config, f, indent=4) - if self.upload_to_object_store: 
assert self.remote_ud is not None - # TODO change to log after other pr log.info( f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}' ) @@ -164,4 +224,31 @@ def _save_checkpoint(self, state: State, logger: Logger): overwrite=self.overwrite, ) - dist.barrier() + elapsed_duration = state.get_elapsed_duration() + if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0: + components = {'model': new_model_instance} + if original_tokenizer is not None: + components['tokenizer'] = original_tokenizer + + log.debug('Logging Hugging Face model to MLFlow') + for i, mlflow_logger in enumerate(self.mlflow_loggers): + log.debug( + f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' + ) + local_save_path = str( + Path(temp_save_dir) / f'mlflow_save_{i}') + + # TODO: Remove after mlflow fixes the bug that makes this necessary + import mlflow + mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' + mlflow_logger.save_model( + flavor='transformers', + transformers_model=components, + path=local_save_path, + **self.mlflow_logging_config, + ) + mlflow_logger.register_model( + model_uri=local_save_path, + name=self.mlflow_registered_model_name, + await_registration_for=3600, + ) diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index d685d0077d..bc41945076 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -10,13 +10,15 @@ import numpy as np import torch +from composer.core.data_spec import DataSpec from omegaconf import DictConfig from omegaconf import OmegaConf as om from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase from llmfoundry.data.packing import BinPackWrapper -from llmfoundry.data.text_data import StreamingTextDataset +from llmfoundry.data.text_data import (StreamingTextDataset, + 
get_tokens_per_batch_func) from llmfoundry.models import utils __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader'] @@ -353,7 +355,7 @@ def build_text_denoising_dataloader( cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, -) -> DataLoader[Dict]: +) -> DataSpec: """Constructor function for a Mixture of Denoisers dataloader. This function constructs a dataloader that can be used to train an @@ -506,7 +508,7 @@ def build_text_denoising_dataloader( 'but cfg.dataset.packing_ratio has not been set. Please set ' +\ 'the latter to turn on packing or remove the former from the config.') - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=device_batch_size, @@ -518,6 +520,12 @@ def build_text_denoising_dataloader( timeout=cfg.get('timeout', 0), ) + token_counting_func = get_tokens_per_batch_func( + pad_token_id=tokenizer.pad_token_id, + decoder_only=cfg.mixture_of_denoisers.decoder_only_format) + + return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func) + def noise_token_sequence( example: Union[torch.Tensor, Mapping[str, Any]], @@ -869,7 +877,9 @@ def _format_tokens_for_decoder_only( tokenizer = build_tokenizer(tokenizer_name=tokenizer_name, tokenizer_kwargs=tokenizer_kwargs) - loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size) + loader = build_text_denoising_dataloader(cfg, tokenizer, + device_batch_size).dataloader + assert isinstance(loader, DataLoader) assert isinstance(loader.dataset, StreamingTextDataset) print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n') diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index ebb7991dde..2dde563ac6 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -6,6 +6,7 @@ import datasets as hf_datasets import torch +from composer.core.data_spec import DataSpec from composer.utils import dist, get_file, 
parse_uri from omegaconf import DictConfig from torch.utils.data import DataLoader @@ -14,6 +15,7 @@ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator from llmfoundry.data.finetuning.tasks import dataset_constructor from llmfoundry.data.packing import BinPackWrapper +from llmfoundry.data.text_data import get_tokens_per_batch_func log = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: int) -> DataLoader: + device_batch_size: int) -> DataSpec: """Builds a finetuning dataloader for training or evaluating. The underlying dataset can be built through one of two code paths: @@ -143,7 +145,7 @@ def build_finetuning_dataloader(cfg: DictConfig, collate_fn, dataloader_batch_size = _build_collate_fn( cfg.dataset, tokenizer, device_batch_size) - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, @@ -193,7 +195,7 @@ def build_finetuning_dataloader(cfg: DictConfig, ) assert dataset is not None - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, @@ -208,6 +210,11 @@ def build_finetuning_dataloader(cfg: DictConfig, timeout=cfg.get('timeout', 0), ) + token_counting_func = get_tokens_per_batch_func( + pad_token_id=tokenizer.pad_token_id) + + return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func) + def _validate_config(dataset_cfg: DictConfig) -> None: """Validates the dataset configuration. 
@@ -442,7 +449,8 @@ def _build_collate_fn( tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) device_batch_size = 2 - dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) + dataloader = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader packing = cfg.dataset.get('packing_ratio') is not None diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index d0a73be801..1532de276e 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -377,7 +377,7 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, dataloader_cfg.dataset.packing_ratio = None dataloader_cfg.dataset.max_leftovers_to_keep = None train_dataloader = build_dataloader(dataloader_cfg, tokenizer, - max(raw_batch_sizes) * 100) + max(raw_batch_sizes) * 100).dataloader # Get a bunch of raw examples big_batch = next(iter(train_dataloader)) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index afdd243adf..93af2f63ed 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -11,6 +11,8 @@ import numpy as np import torch import transformers +from composer.core.data_spec import DataSpec +from composer.core.types import Batch from omegaconf import DictConfig from omegaconf import OmegaConf as om from streaming import Stream, StreamingDataset @@ -237,7 +239,7 @@ def build_text_dataloader( cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, -) -> DataLoader: +) -> DataSpec: assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}' if cfg.dataset.get('group_method', None) is not None: raise NotImplementedError( @@ -281,7 +283,7 @@ def build_text_dataloader( eos_token_id=eos_token_id, bos_token_id=bos_token_id) - return DataLoader( + dl = DataLoader( dataset, collate_fn=collate_fn, batch_size=device_batch_size, @@ -293,6 +295,58 @@ def build_text_dataloader( timeout=cfg.get('timeout', 0), ) + # If we 
pretokenized, we may not have padding, in which case the + # tokenizer may not have a pad_token_id. In this case, we can + # just use the default token counting function. This is correct + # because we do not support training on pretokenized data with padding, + # and if tokenizing on the fly, we require that the tokenizer has a pad token. + token_counting_func = None + if tokenizer.pad_token_id is not None: + token_counting_func = get_tokens_per_batch_func( + pad_token_id=tokenizer.pad_token_id) + + return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func) + + +def get_tokens_per_batch_func(pad_token_id: int, + decoder_only: bool = True + ) -> Callable[[Batch], int]: + """Returns a callable that counts the number of tokens in a batch. + + Args: + pad_token_id (int): The id of the padding token. + decoder_only (bool, optional): Whether to expect the batch to just contain ``input_ids`` (decoder only) + or to also contain ``decoder_input_ids`` (encoder decoder). Defaults to ``True``. + + Returns: + Callable[[Batch], int]: A callable that counts the number of tokens in a batch. 
+ """ + + def get_num_tokens_in_batch(batch: Batch) -> int: + if not isinstance(batch, Mapping) or 'input_ids' not in batch: + raise ValueError( + 'get_tokens_per_batch_func() requires a batch with an input_ids key' + ) + + if not decoder_only and 'decoder_input_ids' not in batch: + raise ValueError( + 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key' + ) + + # Count number of non padding tokens in batch + input_ids_tokens = int( + torch.sum(batch['input_ids'] != pad_token_id).item()) + + # For encoder decoder models only + decoder_input_ids_tokens = 0 + if not decoder_only: + decoder_input_ids_tokens = int( + torch.sum(batch['decoder_input_ids'] != pad_token_id).item()) + + return input_ids_tokens + decoder_input_ids_tokens + + return get_num_tokens_in_batch + # Helpful to test if your dataloader is working locally # Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out @@ -353,7 +407,8 @@ def build_text_dataloader( tokenizer_kwargs = {'model_max_length': args.max_seq_len} tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - loader = build_text_dataloader(cfg, tokenizer, device_batch_size) + loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader + assert isinstance(loader, DataLoader) assert isinstance(loader.dataset, StreamingTextDataset) tokenizer = loader.dataset.tokenizer diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index c29f73739e..4a6d21c873 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -20,6 +20,9 @@ def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], time = Time.from_timestring(time) if isinstance(t_max, str): t_max = Time.from_timestring(t_max) + + assert not isinstance(time, str) and not isinstance(t_max, str) + if time.unit != t_max.unit: raise ValueError(f'{time.unit=} does not match {t_max.unit=}.') @@ -27,6 +30,9 @@ 
def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time], def _raise_if_units_dur(time: Union[str, Time], name: str) -> None: if isinstance(time, str): time = Time.from_timestring(time) + + assert not isinstance(time, str) + if time.unit == TimeUnit('dur'): raise ValueError(f'{name} cannot be in units of "dur".') diff --git a/setup.py b/setup.py index a686dd0808..d0ecc66160 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.3,<0.17', + 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.33,<4.34', 'mosaicml-streaming>=0.6,<0.7', diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 6495eccf65..656b6d52a6 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -3,22 +3,27 @@ import contextlib import os import pathlib +import random import shutil import sys import tempfile from argparse import Namespace from typing import Optional +from unittest.mock import MagicMock import pytest import torch +import transformers from composer.utils import dist, using_torch_2 +from omegaconf import DictConfig from omegaconf import OmegaConf as om from streaming import MDSWriter from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, - build_text_dataloader) + build_text_dataloader, + get_tokens_per_batch_func) from llmfoundry.utils.builders import build_tokenizer # Add repo root to path so we can import scripts and test it @@ -137,7 +142,7 @@ def test_correct_padding(tokenizer_name: str, test_cfg.eval_loader, tokenizer, batch_size, - ) + ).dataloader batch = next(iter(eval_loader)) assert batch['input_ids'].shape == torch.Size([batch_size, 2048]) @@ -228,7 +233,7 @@ def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool, 
tokenizer_kwargs={'model_max_length': max_seq_len}) loader = build_text_denoising_dataloader(cfg, tokenizer, - device_batch_size) + device_batch_size).dataloader batch_ix = 0 for batch in loader: for k in expected_keys: @@ -287,7 +292,8 @@ def test_finetuning_dataloader(decoder_only_format: bool, else: expected_keys += ['decoder_attention_mask', 'decoder_input_ids'] - loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) + loader = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader batch_ix = 0 for batch in loader: for k in expected_keys: @@ -541,7 +547,8 @@ def test_malformed_data( match='Unable to tokenize example') with error_context: - dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) + dl = build_finetuning_dataloader(cfg, tokenizer, + device_batch_size).dataloader if not add_bad_data_error: # +5 because we added samples with just bos/eos in each of prompt/response @@ -552,3 +559,175 @@ def test_malformed_data( actual_num_batches += 1 assert actual_num_batches == expected_num_batches + + +@pytest.mark.parametrize('pad_token_id', [0, 100, 1000]) +@pytest.mark.parametrize('batch_size', [1, 8, 16]) +@pytest.mark.parametrize('model_max_length', [1024, 2048]) +@pytest.mark.parametrize('padding_side', ['left', 'right']) +@pytest.mark.parametrize('add_decoder_input_ids', [True, False]) +def test_token_counting_func(pad_token_id: int, batch_size: int, + model_max_length: int, padding_side: str, + add_decoder_input_ids: bool): + gptt = transformers.AutoTokenizer.from_pretrained('gpt2') + gptt.pad_token_id = pad_token_id + gptt.model_max_length = model_max_length + gptt.padding_side = padding_side + + batch_strings = [] + expected_token_count = 0 + for _ in range(batch_size): + sample_length = random.randint(1, model_max_length) + batch_strings.append(' '.join(['hello'] * sample_length)) + expected_token_count += sample_length + + batch_tokenized = gptt(batch_strings, padding=True, return_tensors='pt') + + if 
add_decoder_input_ids: + decoder_batch_strings = [] + decoder_expected_token_count = 0 + for _ in range(batch_size): + sample_length = random.randint(1, model_max_length) + decoder_batch_strings.append(' '.join(['hello'] * sample_length)) + decoder_expected_token_count += sample_length + expected_token_count += sample_length + batch_tokenized['decoder_input_ids'] = gptt( + decoder_batch_strings, padding=True, + return_tensors='pt')['input_ids'] + + token_counting_func = get_tokens_per_batch_func( + pad_token_id, decoder_only=not add_decoder_input_ids) + + actual_token_count = token_counting_func(batch_tokenized) + + assert actual_token_count == expected_token_count + + +@pytest.mark.parametrize( + 'dataloader_type', + ['finetuning-hf', 'finetuning-streaming', 'denoising', 'text']) +@pytest.mark.parametrize('pad_token_id', [100, None]) +@pytest.mark.parametrize('batch_size', [1, 8]) +@pytest.mark.parametrize('model_max_length', [1024]) +@pytest.mark.parametrize('padding_side', ['left']) +def test_token_counting_func_dataloader_setting( + dataloader_type: str, pad_token_id: Optional[int], batch_size: int, + model_max_length: int, padding_side: str, + monkeypatch: pytest.MonkeyPatch): + gptt = transformers.AutoTokenizer.from_pretrained('gpt2') + gptt.pad_token_id = pad_token_id + gptt.model_max_length = model_max_length + gptt.padding_side = padding_side + + batch_strings = [] + expected_token_count = 0 + for _ in range(batch_size): + sample_length = random.randint( + 1, + model_max_length) if pad_token_id is not None else model_max_length + batch_strings.append(' '.join(['hello'] * sample_length)) + expected_token_count += sample_length + + batch_tokenized = gptt(batch_strings, + padding=True if pad_token_id is not None else False, + return_tensors='pt') + + if dataloader_type == 'denoising': + batch_tokenized['decoder_input_ids'] = batch_tokenized[ + 'input_ids'].clone() + expected_token_count *= 2 + + common_args = { + 'drop_last': False, + 'num_workers': 0, + 
'prefetch_factor': None if using_torch_2() else 2, + 'pin_memory': False, + 'persistent_workers': False, + 'timeout': 0 + } + + if dataloader_type == 'finetuning-hf': + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'hf_name': 'dummy-path', + 'split': 'train', + 'max_seq_len': model_max_length, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + **common_args + }) + monkeypatch.setattr( + 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_hf', + lambda *args, **kwargs: []) + dl = build_finetuning_dataloader(cfg, gptt, batch_size) + elif dataloader_type == 'finetuning-streaming': + cfg = DictConfig({ + 'name': 'finetuning', + 'dataset': { + 'remote': 'dummy-path', + 'local': 'dummy-path', + 'split': 'train', + 'max_seq_len': model_max_length, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + **common_args + }) + monkeypatch.setattr( + 'llmfoundry.data.finetuning.tasks.DatasetConstructor.build_from_streaming', + lambda *args, **kwargs: []) + dl = build_finetuning_dataloader(cfg, gptt, batch_size) + elif dataloader_type == 'text': + cfg = DictConfig({ + 'name': 'text', + 'dataset': { + 'local': 'dummy-path', + 'remote': 'dummy-path', + 'split': 'train', + 'max_seq_len': model_max_length, + 'shuffle': True, + 'shuffle_seed': 0, + }, + **common_args + }) + monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + dl = build_text_dataloader(cfg, gptt, batch_size) + elif dataloader_type == 'denoising': + cfg = DictConfig({ + 'name': 'text_denoising', + 'dataset': { + 'local': 'dummy-path', + 'remote': 'dummy-path', + 'split': 'val_xsmall', + 'shuffle': False, + 'max_seq_len': model_max_length, + 'packing_ratio': None, + 'predownload': 1000, + 'keep_zip': False, + 'num_workers': None + }, + 'mixture_of_denoisers': { + 'decoder_only_format': False, + 
'span_mean_lengths_and_ratios': [[3, .15], [8, .5]], + 'sequence_mask_ratios': 0.25, + }, + **common_args + }) + monkeypatch.setattr('llmfoundry.data.denoising.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + dl = build_text_denoising_dataloader(cfg, gptt, batch_size) + else: + raise NotImplementedError() + + cfg = om.create(cfg) + + actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized) + + assert actual_token_count == expected_token_count diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index c944dcfc97..5bc3ed6d5d 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,8 +5,10 @@ import os import pathlib import sys +from unittest.mock import MagicMock from composer import Trainer +from composer.loggers import MLFlowLogger from composer.utils import dist, get_device from llmfoundry.callbacks import HuggingFaceCheckpointer @@ -17,7 +19,7 @@ sys.path.append(repo_dir) import shutil from argparse import Namespace -from typing import cast +from typing import Optional, cast import pytest import torch @@ -148,6 +150,23 @@ def check_hf_model_equivalence(model1: PreTrainedModel, # so we remove it expected_model_config_dict.pop('_name_or_path') new_model_config_dict.pop('_name_or_path') + + # Special case a couple of differences that correctly occur when saving MPT to huggingface format + # checkpoint + architectures_1 = expected_model_config_dict.pop('architectures', None) + architectures_2 = new_model_config_dict.pop('architectures', None) + if architectures_1 != architectures_2: + assert architectures_1 is None and architectures_2 == ['MPTForCausalLM'] + + auto_map_1 = expected_model_config_dict.pop('auto_map', None) + auto_map_2 = new_model_config_dict.pop('auto_map', None) + if auto_map_1 != auto_map_2: + assert auto_map_1 == {'AutoConfig': 'configuration_mpt.MPTConfig'} + assert auto_map_2 == { + 'AutoConfig': 'configuration_mpt.MPTConfig', + 'AutoModelForCausalLM': 
'modeling_mpt.MPTForCausalLM' + } + assert expected_model_config_dict == new_model_config_dict assert all( torch.equal(p1.cpu(), p2.cpu()) @@ -183,9 +202,11 @@ def test_callback_inits_with_defaults(): @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) -@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded']) +@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) +@pytest.mark.parametrize('log_to_mlflow', [True, False]) def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, - fsdp_state_dict_type: str): + fsdp_state_dict_type: Optional[str], + log_to_mlflow: bool): delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -203,6 +224,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, save_folder=os.path.join(tmp_path, 'checkpoints'), save_interval=f'{huggingface_save_interval_batches}ba', precision=precision_str, + mlflow_registered_model_name='dummy-registered-name' + if log_to_mlflow else None, ) # get small version of each model @@ -324,20 +347,35 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, optimizer = build_optimizer(original_model, optimizer_name, optimizer_config) + mlflow_logger_mock = MagicMock(spec=MLFlowLogger) + mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} + mlflow_logger_mock.save_model = MagicMock() + mlflow_logger_mock.register_model = MagicMock() + mlflow_logger_mock.model_registry_prefix = '' trainer = Trainer( model=original_model, device='gpu', - fsdp_config=fsdp_config, + fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None, train_dataloader=train_dataloader, save_folder=os.path.join(tmp_path, 'checkpoints'), save_interval=f'{save_interval_batches}ba', max_duration=f'{max_duration_batches}ba', callbacks=[checkpointer_callback], + loggers=[mlflow_logger_mock] if log_to_mlflow else [], optimizers=optimizer, 
save_latest_filename=None, ) trainer.fit() + if dist.get_global_rank() == 0: + assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow + else 0) + assert mlflow_logger_mock.register_model.call_count == ( + 1 if log_to_mlflow else 0) + else: + assert mlflow_logger_mock.save_model.call_count == 0 + assert mlflow_logger_mock.register_model.call_count == 0 + # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP with FSDP.summon_full_params(trainer.state.model, @@ -390,8 +428,10 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trust_remote_code=True, ) - check_hf_model_equivalence(trainer.state.model.model.to(precision), - loaded_model) + check_hf_model_equivalence( + trainer.state.model.model.to(precision) if fsdp_state_dict_type + is not None else trainer.state.model.module.model.to(precision), + loaded_model) check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer) delete_transformers_cache()