Commit 9f1bfee

fix bug and add test
dakinggg committed Oct 25, 2023
1 parent d72902a commit 9f1bfee
Showing 2 changed files with 31 additions and 15 deletions.
18 changes: 16 additions & 2 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -8,9 +8,10 @@
 import tempfile
 from pathlib import Path
 from typing import Optional, Union
+import math

 import torch
-from composer.core import Callback, Event, State, Time
+from composer.core import Callback, Event, State, Time, TimeUnit
 from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
@@ -83,6 +84,13 @@ def __init__(

         self.huggingface_folder_name_fstr = os.path.join(
             'huggingface', huggingface_folder_name)
+
+        if isinstance(save_interval, str):
+            save_interval = Time.from_timestring(save_interval)
+        if isinstance(save_interval, int):
+            save_interval = Time(save_interval, TimeUnit.EPOCH)
+
+        self.save_interval = save_interval
         self.check_interval = create_interval_scheduler(
             save_interval, include_end_of_training=True)
         self.upload_to_object_store = (self.backend != '')
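
The added block accepts either a timestring or a bare int for save_interval and coerces both to a composer Time object before any unit comparisons happen. A minimal standalone sketch of the same coercion, assuming composer's Time/TimeUnit API; normalize_save_interval is a hypothetical helper name, not part of the callback:

```python
# Sketch of the save_interval coercion above; Time/TimeUnit are composer's
# real API, normalize_save_interval is a hypothetical name for illustration.
from typing import Union

from composer.core import Time, TimeUnit


def normalize_save_interval(save_interval: Union[str, int, Time]) -> Time:
    # Timestrings like '1dur', '2ba', or '1ep' parse to a Time with that unit.
    if isinstance(save_interval, str):
        save_interval = Time.from_timestring(save_interval)
    # Bare integers are interpreted as a number of epochs.
    if isinstance(save_interval, int):
        save_interval = Time(save_interval, TimeUnit.EPOCH)
    return save_interval


print(normalize_save_interval('1dur').unit)  # TimeUnit.DURATION
print(normalize_save_interval(2).unit)       # TimeUnit.EPOCH
```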
@@ -225,7 +233,13 @@ def _save_checkpoint(self, state: State, logger: Logger):
        )

         elapsed_duration = state.get_elapsed_duration()
-        if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
+
+        # If the save interval is specified as 1dur and the max duration is in epoch units,
+        # we need a special case to identify that we are on the last batch and should write the mlflow checkpoint.
+        is_last_batch = False
+        if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
+            is_last_batch = int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0
+        if self.mlflow_registered_model_name is not None and ((elapsed_duration is not None and elapsed_duration >= 1.0) or is_last_batch):
             components = {'model': new_model_instance}
             if original_tokenizer is not None:
                 components['tokenizer'] = original_tokenizer
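
To see why this guard fires exactly once, plug in the numbers from the test below (dataset_size 14, per-device batch size 1, two GPUs, max_duration '1ep'). A worked sketch of the modulus check, with dataloader_len standing in for state.dataloader_len:

```python
import math

# Values mirroring the test below: 14 samples, per-device batch size 1,
# 2 GPUs -> dataloader_len of 7 batches per epoch, trained for 1 epoch.
max_duration_epochs = 1
dataloader_len = math.ceil(14 / (1 * 2))  # 7


def is_last_batch(batch: int) -> bool:
    # Same check as the guard above: the final batch is the only multiple of
    # ceil(epochs * batches-per-epoch) reached during training.
    return batch % math.ceil(max_duration_epochs * dataloader_len) == 0


print([b for b in range(1, 8) if is_last_batch(b)])  # [7]
```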
28 changes: 15 additions & 13 deletions tests/test_hf_conversion_script.py
@@ -251,25 +251,29 @@ def test_callback_inits_with_defaults():
 @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
 @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
 @pytest.mark.parametrize('log_to_mlflow', [True, False])
+@pytest.mark.parametrize('hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
 def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
                                          fsdp_state_dict_type: Optional[str],
-                                         log_to_mlflow: bool):
+                                         log_to_mlflow: bool,
+                                         hf_save_interval: str,
+                                         save_interval: str,
+                                         max_duration: str,
+                                         expected_hf_checkpoints: int,
+                                         expected_normal_checkpoints: int):
     delete_transformers_cache()

     dist.initialize_dist(get_device('gpu'))

     max_seq_len = 16
-    save_interval_batches = 2
-    huggingface_save_interval_batches = 3
     device_batch_size = 1
     dataset_size = 14
-    max_duration_batches = 7
     precision_str = 'bfloat16'
     precision = torch.bfloat16
+    batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))

     checkpointer_callback = HuggingFaceCheckpointer(
         save_folder=os.path.join(tmp_path, 'checkpoints'),
-        save_interval=f'{huggingface_save_interval_batches}ba',
+        save_interval=hf_save_interval,
         precision=precision_str,
         mlflow_registered_model_name='dummy-registered-name'
         if log_to_mlflow else None,
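
The expected counts in the parametrization can be verified by hand, assuming create_interval_scheduler fires on every interval boundary and include_end_of_training=True adds the final batch when it is not already a boundary. A sketch; expected_checkpoints is a hypothetical helper, not part of the test:

```python
def expected_checkpoints(interval_batches: int, total_batches: int) -> int:
    # One checkpoint per interval boundary, plus one at the end of training
    # if the final batch does not already land on a boundary.
    boundaries = total_batches // interval_batches
    return boundaries + (1 if total_batches % interval_batches else 0)


# First case ('3ba', '2ba', '7ba', 3, 4):
print(expected_checkpoints(3, 7))  # 3 -> ba3, ba6, ba7 (HF checkpoints)
print(expected_checkpoints(2, 7))  # 4 -> ba2, ba4, ba6, ba7 (normal)
# Second case ('1dur', '2ba', '1ep', 1, 4): '1dur' fires once, at the end.
```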
@@ -405,8 +409,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
-        save_interval=f'{save_interval_batches}ba',
-        max_duration=f'{max_duration_batches}ba',
+        save_interval=save_interval,
+        max_duration=max_duration,
         callbacks=[checkpointer_callback],
         loggers=[mlflow_logger_mock] if log_to_mlflow else [],
         optimizers=optimizer,
@@ -442,15 +446,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         name for name in os.listdir(
             os.path.join(tmp_path, 'checkpoints', 'huggingface'))
     ]
-    assert len(normal_checkpoints) == math.ceil(max_duration_batches /
-                                                save_interval_batches)
-    assert len(huggingface_checkpoints) == math.ceil(
-        max_duration_batches / huggingface_save_interval_batches)
+    assert len(normal_checkpoints) == expected_normal_checkpoints
+    assert len(huggingface_checkpoints) == expected_hf_checkpoints

     # Load the last huggingface checkpoint
     loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
         os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{max_duration_batches}'),
+                     f'ba{batches_per_epoch}'),
         trust_remote_code=True,
     )

@@ -471,7 +473,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,

     loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
         os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                     f'ba{max_duration_batches}'),
+                     f'ba{batches_per_epoch}'),
         trust_remote_code=True,
     )
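
Both parametrizations finish after the same number of batches (7ba explicitly, or 1ep at 7 batches per epoch), so the last HuggingFace checkpoint folder is ba7 either way, which is why the test can load f'ba{batches_per_epoch}' unconditionally. A quick check of the folder name, using the test's fixed sizes:

```python
import math

# dataset_size=14, device_batch_size=1, 2 GPUs -> 7 batches per epoch,
# matching the explicit '7ba' max_duration of the other parametrization.
batches_per_epoch = math.ceil(14 / (1 * 2))
print(f'ba{batches_per_epoch}')  # ba7
```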

