Fix mlflow model logging bug #692

Merged (5 commits, Oct 25, 2023)

Changes from 4 commits
23 changes: 21 additions & 2 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -4,13 +4,14 @@
 import contextlib
 import copy
 import logging
+import math
 import os
 import tempfile
 from pathlib import Path
 from typing import Optional, Union
 
 import torch
-from composer.core import Callback, Event, State, Time
+from composer.core import Callback, Event, State, Time, TimeUnit
 from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
@@ -83,6 +84,13 @@ def __init__(
 
         self.huggingface_folder_name_fstr = os.path.join(
             'huggingface', huggingface_folder_name)
+
+        if isinstance(save_interval, str):
+            save_interval = Time.from_timestring(save_interval)
+        if isinstance(save_interval, int):
+            save_interval = Time(save_interval, TimeUnit.EPOCH)
+
+        self.save_interval = save_interval
         self.check_interval = create_interval_scheduler(
             save_interval, include_end_of_training=True)
         self.upload_to_object_store = (self.backend != '')
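
Aside (illustration, not part of the diff): the normalization above accepts either a timestring or a bare int. A small sketch of how composer's Time parses the intervals used in this PR; the assert values are just for illustration.

    from composer.core import Time, TimeUnit

    # Timestrings parse into a value plus a unit:
    interval = Time.from_timestring('1dur')
    assert interval.unit == TimeUnit.DURATION and interval.value == 1
    interval = Time.from_timestring('2ba')
    assert interval.unit == TimeUnit.BATCH and interval.value == 2

    # A bare int save_interval is treated as a number of epochs:
    interval = Time(5, TimeUnit.EPOCH)
    assert interval.unit == TimeUnit.EPOCH and interval.value == 5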
@@ -225,7 +233,18 @@ def _save_checkpoint(self, state: State, logger: Logger):
             )
 
         elapsed_duration = state.get_elapsed_duration()
-        if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
+
+        # If the save interval is specified as 1dur and the max duration is in epoch units,
+        # we need a special case to identify that we are on the last batch and should write the mlflow checkpoint.
+        is_last_batch = False
+        assert state.max_duration is not None  # for pyright
+        if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
+            assert state.dataloader_len is not None  # for pyright
+            is_last_batch = int(state.timestamp.batch) % math.ceil(
+                state.max_duration.value * state.dataloader_len) == 0
+        if self.mlflow_registered_model_name is not None and (
+            (elapsed_duration is not None and
+             elapsed_duration >= 1.0) or is_last_batch):
             components = {'model': new_model_instance}
             if original_tokenizer is not None:
                 components['tokenizer'] = original_tokenizer
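Aside (illustration, not part of the diff): a minimal replay of the new is_last_batch check, assuming composer's Time/TimeUnit semantics and a hypothetical 1-epoch run with 7 batches per epoch. It shows the check firing only on the final batch, which is what lets a '1dur' save interval trigger the mlflow registration.

    import math

    from composer.core import Time, TimeUnit

    save_interval = Time.from_timestring('1dur')  # one full training duration
    max_duration = Time.from_timestring('1ep')
    dataloader_len = 7  # hypothetical batches per epoch

    for batch in range(1, 8):  # state.timestamp.batch after each batch ends
        is_last_batch = False
        if (save_interval.unit == TimeUnit.DURATION and save_interval.value == 1
                and max_duration.unit == TimeUnit.EPOCH):
            is_last_batch = batch % math.ceil(
                max_duration.value * dataloader_len) == 0
        print(batch, is_last_batch)  # True only for batch 7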
29 changes: 16 additions & 13 deletions tests/test_hf_conversion_script.py
@@ -251,25 +251,30 @@ def test_callback_inits_with_defaults():
 @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
 @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
 @pytest.mark.parametrize('log_to_mlflow', [True, False])
+@pytest.mark.parametrize(
+    'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
+    [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
 def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
                                          fsdp_state_dict_type: Optional[str],
-                                         log_to_mlflow: bool):
+                                         log_to_mlflow: bool,
+                                         hf_save_interval: str,
+                                         save_interval: str, max_duration: str,
+                                         expected_hf_checkpoints: int,
+                                         expected_normal_checkpoints: int):
     delete_transformers_cache()
 
     dist.initialize_dist(get_device('gpu'))
 
     max_seq_len = 16
-    save_interval_batches = 2
-    huggingface_save_interval_batches = 3
     device_batch_size = 1
     dataset_size = 14
-    max_duration_batches = 7
     precision_str = 'bfloat16'
     precision = torch.bfloat16
+    batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))
 
     checkpointer_callback = HuggingFaceCheckpointer(
         save_folder=os.path.join(tmp_path, 'checkpoints'),
-        save_interval=f'{huggingface_save_interval_batches}ba',
+        save_interval=hf_save_interval,
         precision=precision_str,
         mlflow_registered_model_name='dummy-registered-name'
         if log_to_mlflow else None,
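
Aside (illustration, not part of the diff): the expected checkpoint counts in the new parametrization follow from simple interval arithmetic, given the test's dataset_size=14, device_batch_size=1, and an assumed world size of 2:

    import math

    dataset_size, device_batch_size, world_size = 14, 1, 2
    batches_per_epoch = math.ceil(dataset_size / (device_batch_size * world_size))  # 7

    # Case ('3ba', '2ba', '7ba'): saves every 2 (resp. 3) batches over 7 batches.
    assert math.ceil(7 / 2) == 4  # expected_normal_checkpoints
    assert math.ceil(7 / 3) == 3  # expected_hf_checkpoints

    # Case ('1dur', '2ba', '1ep'): one epoch is 7 batches here, and a '1dur'
    # interval writes a single HF checkpoint at the end of training.
    assert math.ceil(batches_per_epoch / 2) == 4  # expected_normal_checkpoints
    expected_hf_checkpoints = 1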
@@ -405,8 +410,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
-        save_interval=f'{save_interval_batches}ba',
-        max_duration=f'{max_duration_batches}ba',
+        save_interval=save_interval,
+        max_duration=max_duration,
         callbacks=[checkpointer_callback],
         loggers=[mlflow_logger_mock] if log_to_mlflow else [],
         optimizers=optimizer,
@@ -442,15 +447,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         name for name in os.listdir(
             os.path.join(tmp_path, 'checkpoints', 'huggingface'))
     ]
-    assert len(normal_checkpoints) == math.ceil(max_duration_batches /
-                                                save_interval_batches)
-    assert len(huggingface_checkpoints) == math.ceil(
-        max_duration_batches / huggingface_save_interval_batches)
+    assert len(normal_checkpoints) == expected_normal_checkpoints
+    assert len(huggingface_checkpoints) == expected_hf_checkpoints
 
     # Load the last huggingface checkpoint
     loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
         os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{max_duration_batches}'),
+                     f'ba{batches_per_epoch}'),
         trust_remote_code=True,
     )
 
@@ -471,7 +474,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
 
     loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
         os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{max_duration_batches}'),
+                     f'ba{batches_per_epoch}'),
         trust_remote_code=True,
     )
