diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index aa3beda513..a219b98ddb 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -8,9 +8,10 @@
 import tempfile
 from pathlib import Path
 from typing import Optional, Union
+import math
 
 import torch
-from composer.core import Callback, Event, State, Time
+from composer.core import Callback, Event, State, Time, TimeUnit
 from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
@@ -83,6 +84,13 @@ def __init__(
 
         self.huggingface_folder_name_fstr = os.path.join(
             'huggingface', huggingface_folder_name)
+
+        if isinstance(save_interval, str):
+            save_interval = Time.from_timestring(save_interval)
+        if isinstance(save_interval, int):
+            save_interval = Time(save_interval, TimeUnit.EPOCH)
+
+        self.save_interval = save_interval
         self.check_interval = create_interval_scheduler(
             save_interval, include_end_of_training=True)
         self.upload_to_object_store = (self.backend != '')
@@ -225,7 +233,13 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     )
 
         elapsed_duration = state.get_elapsed_duration()
-        if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
+
+        # If the save interval is specified as 1dur, and the max duration is in epoch units
+        # we need a special case to identify we are on the last batch and should write the mlflow checkpoint
+        is_last_batch = False
+        if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
+            is_last_batch = int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0
+        if self.mlflow_registered_model_name is not None and ((elapsed_duration is not None and elapsed_duration >= 1.0) or is_last_batch):
             components = {'model': new_model_instance}
             if original_tokenizer is not None:
                 components['tokenizer'] = original_tokenizer
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index fcb2cc3a7e..71c35f3723 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -251,25 +251,29 @@ def test_callback_inits_with_defaults():
 @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
 @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
 @pytest.mark.parametrize('log_to_mlflow', [True, False])
+@pytest.mark.parametrize('hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
 def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
                                          fsdp_state_dict_type: Optional[str],
-                                         log_to_mlflow: bool):
+                                         log_to_mlflow: bool,
+                                         hf_save_interval: str,
+                                         save_interval: str,
+                                         max_duration: str,
+                                         expected_hf_checkpoints: int,
+                                         expected_normal_checkpoints: int):
     delete_transformers_cache()
 
     dist.initialize_dist(get_device('gpu'))
 
     max_seq_len = 16
-    save_interval_batches = 2
-    huggingface_save_interval_batches = 3
     device_batch_size = 1
     dataset_size = 14
-    max_duration_batches = 7
     precision_str = 'bfloat16'
     precision = torch.bfloat16
+    batches_per_epoch = math.ceil(dataset_size / (device_batch_size*2))
 
     checkpointer_callback = HuggingFaceCheckpointer(
         save_folder=os.path.join(tmp_path, 'checkpoints'),
-        save_interval=f'{huggingface_save_interval_batches}ba',
+        save_interval=hf_save_interval,
         precision=precision_str,
         mlflow_registered_model_name='dummy-registered-name'
         if log_to_mlflow else None,
@@ -405,8 +409,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
-        save_interval=f'{save_interval_batches}ba',
-        max_duration=f'{max_duration_batches}ba',
+        save_interval=save_interval,
+        max_duration=max_duration,
         callbacks=[checkpointer_callback],
         loggers=[mlflow_logger_mock] if log_to_mlflow else [],
         optimizers=optimizer,
@@ -442,15 +446,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         name for name in os.listdir(
             os.path.join(tmp_path, 'checkpoints', 'huggingface'))
     ]
-    assert len(normal_checkpoints) == math.ceil(max_duration_batches /
-                                                save_interval_batches)
-    assert len(huggingface_checkpoints) == math.ceil(
-        max_duration_batches / huggingface_save_interval_batches)
+    assert len(normal_checkpoints) == expected_normal_checkpoints
+    assert len(huggingface_checkpoints) == expected_hf_checkpoints
 
     # Load the last huggingface checkpoint
     loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
         os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{max_duration_batches}'),
+                     f'ba{batches_per_epoch}'),
         trust_remote_code=True,
     )
 
@@ -471,7 +473,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
 
     loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
         os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{max_duration_batches}'),
+                     f'ba{batches_per_epoch}'),
         trust_remote_code=True,
     )
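
Note (not part of the diff): a minimal standalone sketch of how the new is_last_batch check is expected to behave under the test's second parametrization ('1dur', '2ba', '1ep', 1, 4). The local variables below are illustrative stand-ins for int(state.timestamp.batch), state.max_duration.value, and state.dataloader_len; the num_devices value of 2 is an assumption mirroring the two-GPU setup implied by batches_per_epoch in the test.

import math

# Stand-in values mirroring the test: 14 samples, per-device batch size 1, 2 devices.
dataset_size = 14
device_batch_size = 1
num_devices = 2
dataloader_len = math.ceil(dataset_size / (device_batch_size * num_devices))  # 7 batches per epoch
max_duration_epochs = 1  # '1ep'

for batch in range(1, dataloader_len + 1):
    # Mirrors: int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0
    is_last_batch = batch % math.ceil(max_duration_epochs * dataloader_len) == 0
    print(batch, is_last_batch)

# Only batch 7 evaluates to True, so with a '1dur' save interval the MLflow
# registration path runs exactly once, on the final batch of the single epoch,
# which is why the test loads the checkpoint folder named f'ba{batches_per_epoch}'.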