From db8be5e3481fcc8c416df3f638a51ce1206d9bc8 Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Wed, 10 Jan 2024 18:56:42 +0000
Subject: [PATCH] fix style etc.

---
Reviewer notes (text between the three-dash line and the diffstat is ignored
by `git am`):
* This copy of the patch had its line breaks and indentation collapsed; both
  are restored here. The indentation of the context lines in
  hf_checkpointer.py is inferred from the hunk line counts and the 80-column
  wrapping -- TODO confirm against the target tree, or apply with
  `git apply --ignore-whitespace`.
* The emoji string literals in the two new files were mis-decoded and are
  re-encoded as UTF-8. The DummyData value could not be decoded
  unambiguously and is a best guess -- TODO confirm.
* The oci_temp_file docstring, the comment above `logs = []` and the
  list-wrapped any() were corrected without changing any line count, so the
  hunk headers and diffstat still hold, but the post-image blob ids on the
  `index` lines of the two new files no longer match.

 llmfoundry/callbacks/hf_checkpointer.py |  9 ++-
 tests/callbacks/test_hf_checkpointer.py | 92 +++++++++++++++++++++++++
 tests/conftest.py                       |  1 +
 tests/fixtures/object_stores.py         | 39 +++++++++++
 4 files changed, 139 insertions(+), 2 deletions(-)
 create mode 100644 tests/callbacks/test_hf_checkpointer.py
 create mode 100644 tests/fixtures/object_stores.py

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 491d510188..31c7f5ba06 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -245,11 +245,16 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     )
 
                 if self.remote_ud is not None:
-                    log.info(f'Uploading HuggingFace formatted checkpoint')
                     for filename in os.listdir(temp_save_dir):
+                        remote_file_name = os.path.join(save_dir, filename)
+                        remote_file_uri = self.remote_ud.remote_backend.get_uri(
+                            remote_file_name)
+                        log.info(
+                            f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
+                        )
                         self.remote_ud.upload_file(
                             state=state,
-                            remote_file_name=os.path.join(save_dir, filename),
+                            remote_file_name=remote_file_name,
                             file_path=Path(os.path.join(temp_save_dir,
                                                         filename)),
                             overwrite=self.overwrite,
diff --git a/tests/callbacks/test_hf_checkpointer.py b/tests/callbacks/test_hf_checkpointer.py
new file mode 100644
index 0000000000..bab57c65e0
--- /dev/null
+++ b/tests/callbacks/test_hf_checkpointer.py
@@ -0,0 +1,92 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Dict, List
+from unittest.mock import patch
+
+from composer.core import State, Time, TimeUnit
+from composer.loggers import Logger
+
+from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
+from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
+
+dummy_s3_path = 's3://dummy/path'
+dummy_oci_path = 'oci://dummypath'
+dummy_gc_path = 'gs://dummy/path'
+dummy_uc_path = 'dbfs://dummypath/Volumes/the_catalog/the_schema/yada_yada'
+
+dummy_save_interval = Time(1, TimeUnit.EPOCH)
+
+
+def dummy_log_info(log_output: List[str]):
+    def _dummy_log_info(*msgs: str):
+        log_output.extend(msgs)
+
+    return _dummy_log_info
+
+
+def dummy_upload_file(*_, **__: Dict[str, Any]):
+    pass
+
+
+def assert_checkpoint_saves_to_uri(
+        uri: str, build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM]):
+    uri_base = uri.split('://')[0]
+    model = build_tiny_hf_mpt()
+
+    dummy_state = State(model=model,
+                        rank_zero_seed=42,
+                        run_name='dummy_run',
+                        device='cpu')
+    dummy_logger = Logger(dummy_state)
+    # capture the arguments of every Logger.info call made during the save
+    logs = []
+    with patch('logging.Logger.info', dummy_log_info(logs)):
+        my_checkpointer = HuggingFaceCheckpointer(
+            save_folder=uri, save_interval=dummy_save_interval)
+        my_checkpointer.remote_ud.upload_file = dummy_upload_file
+        my_checkpointer._save_checkpoint(dummy_state, dummy_logger)
+
+    assert any(uri_base in str(log) for log in logs)
+
+
+def test_checkpoint_saves_to_s3(
+        build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM]):
+    assert_checkpoint_saves_to_uri(dummy_s3_path, build_tiny_hf_mpt)
+
+
+class DummyData:
+
+    def __init__(self, *_, **__: Any):
+        self.data = '🪐'
+        pass
+
+
+class DummyClient:
+
+    def __init__(self, *_, **__: Any):
+        pass
+
+    def get_namespace(self, *_, **__: Any):
+        return DummyData()
+
+
+def test_checkpoint_saves_to_oci(
+        build_tiny_hf_mpt: Callable[...,
+                                    ComposerMPTCausalLM], oci_temp_file: None):
+    with patch('oci.config.from_file', lambda _: {}), \
+        patch('oci.object_storage.ObjectStorageClient', lambda *_, **__: DummyClient()), \
+        patch('oci.object_storage.UploadManager', lambda *_, **__: None):
+        assert_checkpoint_saves_to_uri(dummy_oci_path, build_tiny_hf_mpt)
+
+
+def test_checkpoint_saves_to_gc(
+        build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM],
+        gcs_account_credentials: None):
+    assert_checkpoint_saves_to_uri(dummy_gc_path, build_tiny_hf_mpt)
+
+
+def test_checkpoint_saves_to_uc(
+        build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM],
+        uc_account_credentials: None):
+    assert_checkpoint_saves_to_uri(dummy_uc_path, build_tiny_hf_mpt)
diff --git a/tests/conftest.py b/tests/conftest.py
index 545dc7e38f..eff181a851 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,6 +18,7 @@
 
 # Add the path of any pytest fixture files you want to make global
 pytest_plugins = [
+    'tests.fixtures.object_stores',
     'tests.fixtures.autouse',
     'tests.fixtures.models',
     'tests.fixtures.data',
diff --git a/tests/fixtures/object_stores.py b/tests/fixtures/object_stores.py
new file mode 100644
index 0000000000..ae03add6eb
--- /dev/null
+++ b/tests/fixtures/object_stores.py
@@ -0,0 +1,39 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import tempfile
+
+from pytest import fixture
+
+
+@fixture
+def gcs_account_credentials():
+    """Mocked GCS Credentials for service level account."""
+    os.environ['GCS_KEY'] = '🗝️'
+    os.environ['GCS_SECRET'] = '🤫'
+    yield
+    del os.environ['GCS_KEY']
+    del os.environ['GCS_SECRET']
+
+
+@fixture
+def uc_account_credentials():
+    """Mocked UC Credentials for service level account."""
+    os.environ['DATABRICKS_HOST'] = '⛵️'
+    os.environ['DATABRICKS_TOKEN'] = '😶‍🌫️'
+    yield
+    del os.environ['DATABRICKS_HOST']
+    del os.environ['DATABRICKS_TOKEN']
+
+
+@fixture
+def oci_temp_file():
+    """Temporary file exposed as the mocked OCI config via OCI_CONFIG_FILE."""
+    file = tempfile.NamedTemporaryFile()
+    os.environ['OCI_CONFIG_FILE'] = file.name
+
+    yield
+
+    file.close()
+    del os.environ['OCI_CONFIG_FILE']