Skip to content

Commit

Permalink
fix style etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed Jan 10, 2024
1 parent 4772ba2 commit db8be5e
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 2 deletions.
9 changes: 7 additions & 2 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,16 @@ def _save_checkpoint(self, state: State, logger: Logger):
)

if self.remote_ud is not None:
log.info(f'Uploading HuggingFace formatted checkpoint')
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
)
self.remote_ud.upload_file(
state=state,
remote_file_name=os.path.join(save_dir, filename),
remote_file_name=remote_file_name,
file_path=Path(os.path.join(temp_save_dir,
filename)),
overwrite=self.overwrite,
Expand Down
92 changes: 92 additions & 0 deletions tests/callbacks/test_hf_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List
from unittest.mock import patch

from composer.core import State, Time, TimeUnit
from composer.loggers import Logger

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM

# Placeholder remote URIs, one per URI scheme (s3, oci, gs, dbfs). Nothing is
# uploaded to them: the helper in this module replaces the uploader's
# ``upload_file`` with a no-op before saving.
dummy_s3_path = 's3://dummy/path'
dummy_oci_path = 'oci://dummypath'
dummy_gc_path = 'gs://dummy/path'
dummy_uc_path = 'dbfs://dummypath/Volumes/the_catalog/the_schema/yada_yada'

# Save interval passed to the checkpointer constructor: one epoch.
dummy_save_interval = Time(1, TimeUnit.EPOCH)


def dummy_log_info(log_output: List[str]):
    """Build a logging stand-in that records its arguments.

    Args:
        log_output: List that receives every positional argument passed to
            the returned callable, in call order. It is mutated in place.

    Returns:
        A callable accepting any number of positional arguments; each call
        appends those arguments to ``log_output`` and returns ``None``.
    """

    def _record(*msgs: str):
        # Append one by one; equivalent to extending with the whole tuple.
        for msg in msgs:
            log_output.append(msg)

    return _record


def dummy_upload_file(*_, **__: Dict[str, Any]):
    """Accept any arguments and do nothing.

    Used as a replacement for an uploader's ``upload_file`` so that no file
    is actually transferred.
    """
    return None


def assert_checkpoint_saves_to_uri(
        uri: str, build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM]):
    """Save a HuggingFace checkpoint to ``uri`` and check the logged output.

    A tiny model is built, wrapped in a ``State`` and ``Logger``, and handed
    to a ``HuggingFaceCheckpointer`` whose ``upload_file`` is replaced by a
    no-op. ``logging.Logger.info`` is patched for the duration so that
    everything passed to it is captured.

    Args:
        uri: Remote save folder, e.g. ``'s3://dummy/path'``.
        build_tiny_hf_mpt: Factory returning the model to checkpoint.

    Raises:
        AssertionError: If no captured entry contains the scheme of ``uri``
            (the text before ``'://'``) in its string form.
    """
    # Everything before the first '://' (the whole string if it is absent).
    scheme = uri.partition('://')[0]
    tiny_model = build_tiny_hf_mpt()

    state = State(
        model=tiny_model,
        rank_zero_seed=42,
        run_name='dummy_run',
        device='cpu',
    )
    logger = Logger(state)

    captured = []
    # NOTE(review): the patch replaces a method on the Logger class with a
    # plain function, so captured entries presumably include the logger
    # instance itself -- hence the str() conversion in the check below.
    with patch('logging.Logger.info', dummy_log_info(captured)):
        checkpointer = HuggingFaceCheckpointer(
            save_folder=uri, save_interval=dummy_save_interval)
        # Never talk to a real object store.
        checkpointer.remote_ud.upload_file = dummy_upload_file
        checkpointer._save_checkpoint(state, logger)

    assert any(scheme in str(entry) for entry in captured)


def test_checkpoint_saves_to_s3(
        build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM]):
    """Run the checkpoint-save check against the dummy ``s3://`` URI."""
    assert_checkpoint_saves_to_uri(
        uri=dummy_s3_path,
        build_tiny_hf_mpt=build_tiny_hf_mpt,
    )


class DummyData:
    """Minimal object exposing a ``data`` attribute with a fixed value."""

    def __init__(self, *_, **__: Any):
        """Ignore all arguments and set ``data`` to a placeholder string."""
        self.data = '🪐'


class DummyClient:
    """Stand-in client whose ``get_namespace`` returns a ``DummyData``."""

    def __init__(self, *_, **__: Any):
        """Accept and discard any constructor arguments."""

    def get_namespace(self, *_, **__: Any):
        """Return a fresh ``DummyData`` regardless of the arguments given."""
        namespace = DummyData()
        return namespace


def test_checkpoint_saves_to_oci(
        build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM],
        oci_temp_file: None):
    """Run the checkpoint-save check against the dummy ``oci://`` URI.

    The OCI config loader, object-storage client and upload manager are
    patched with local stand-ins for the duration of the check.
    """
    # Nested blocks enter the patches in the same order as a single
    # comma-separated ``with`` statement would.
    with patch('oci.config.from_file', lambda _: {}):
        with patch('oci.object_storage.ObjectStorageClient',
                   lambda *_, **__: DummyClient()):
            with patch('oci.object_storage.UploadManager',
                       lambda *_, **__: None):
                assert_checkpoint_saves_to_uri(dummy_oci_path,
                                               build_tiny_hf_mpt)


def test_checkpoint_saves_to_gc(
        build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM],
        gcs_account_credentials: None):
    """Run the checkpoint-save check against the dummy ``gs://`` URI.

    ``gcs_account_credentials`` is requested only for its setup/teardown
    side effects; its value is unused.
    """
    assert_checkpoint_saves_to_uri(
        uri=dummy_gc_path,
        build_tiny_hf_mpt=build_tiny_hf_mpt,
    )


def test_checkpoint_saves_to_uc(
        build_tiny_hf_mpt: Callable[..., ComposerMPTCausalLM],
        uc_account_credentials: None):
    """Run the checkpoint-save check against the dummy ``dbfs://`` URI.

    ``uc_account_credentials`` is requested only for its setup/teardown
    side effects; its value is unused.
    """
    assert_checkpoint_saves_to_uri(
        uri=dummy_uc_path,
        build_tiny_hf_mpt=build_tiny_hf_mpt,
    )
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# Add the path of any pytest fixture files you want to make global
pytest_plugins = [
'tests.fixtures.object_stores',
'tests.fixtures.autouse',
'tests.fixtures.models',
'tests.fixtures.data',
Expand Down
39 changes: 39 additions & 0 deletions tests/fixtures/object_stores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile

from pytest import fixture


@fixture
def gcs_account_credentials():
    """Mocked GCS Credentials for service level account.

    Sets ``GCS_KEY`` and ``GCS_SECRET`` to placeholder values while the test
    runs. On teardown each variable is put back to the value it had before
    the fixture ran, or removed if it was not set, so a developer's real
    credentials are neither clobbered nor deleted.
    """
    mocked = {'GCS_KEY': '🗝️', 'GCS_SECRET': '🤫'}
    # Remember the ambient values (None means "was not set").
    saved = {name: os.environ.get(name) for name in mocked}
    os.environ.update(mocked)
    try:
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                # pop() tolerates a test having already removed the variable.
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous


@fixture
def uc_account_credentials():
    """Mocked UC Credentials for service level account.

    Sets ``DATABRICKS_HOST`` and ``DATABRICKS_TOKEN`` to placeholder values
    while the test runs. On teardown each variable is put back to the value
    it had before the fixture ran, or removed if it was not set, so a
    developer's real credentials are neither clobbered nor deleted.
    """
    mocked = {'DATABRICKS_HOST': '⛵️', 'DATABRICKS_TOKEN': '😶‍🌫️'}
    # Remember the ambient values (None means "was not set").
    saved = {name: os.environ.get(name) for name in mocked}
    os.environ.update(mocked)
    try:
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                # pop() tolerates a test having already removed the variable.
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous


@fixture
def oci_temp_file():
    """Point ``OCI_CONFIG_FILE`` at an empty temporary file for one test.

    The temporary file exists for the duration of the test and is removed
    afterwards. ``OCI_CONFIG_FILE`` is then put back to the value it had
    before the fixture ran, or removed if it was not set.
    """
    previous = os.environ.get('OCI_CONFIG_FILE')
    # The context manager guarantees the temp file is closed (and deleted)
    # even if teardown is interrupted.
    with tempfile.NamedTemporaryFile() as file:
        os.environ['OCI_CONFIG_FILE'] = file.name
        try:
            yield
        finally:
            if previous is None:
                # pop() tolerates a test having already removed the variable.
                os.environ.pop('OCI_CONFIG_FILE', None)
            else:
                os.environ['OCI_CONFIG_FILE'] = previous

0 comments on commit db8be5e

Please sign in to comment.