Skip to content

Commit

Permalink
Add tests for save_folder=None
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Dec 6, 2024
1 parent 7d2a3c2 commit cdd3f7e
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 78 deletions.
48 changes: 24 additions & 24 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,32 +316,32 @@ def __init__(
# mlflow config setup
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
passed_metadata = mlflow_logging_config.get('metadata', {})
mlflow_logging_config['metadata'] = passed_metadata
mlflow_logging_config.setdefault('task', 'llm/v1/completions')

# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
passed_metadata = mlflow_logging_config.get('metadata', {})
mlflow_logging_config['metadata'] = passed_metadata
mlflow_logging_config.setdefault('task', 'llm/v1/completions')

default_input_example = {
'prompt': np.array(['What is Machine Learning?']),
}
is_chat = mlflow_logging_config['task'].endswith('chat') or (
mlflow_logging_config['metadata'] is not None and
mlflow_logging_config['metadata'].get('task',
'').endswith('chat')
)
if is_chat:
default_input_example = {
'prompt': np.array(['What is Machine Learning?']),
'messages': [{
'role': 'user',
'content': 'What is Machine Learning?',
}],
}
is_chat = mlflow_logging_config['task'].endswith('chat') or (
mlflow_logging_config['metadata'] is not None and
mlflow_logging_config['metadata'].get('task',
'').endswith('chat')
)
if is_chat:
default_input_example = {
'messages': [{
'role': 'user',
'content': 'What is Machine Learning?',
}],
}
mlflow_logging_config.setdefault(
'input_example',
default_input_example,
)
mlflow_logging_config.setdefault(
'input_example',
default_input_example,
)

self.mlflow_logging_config = mlflow_logging_config
if 'metadata' in self.mlflow_logging_config:
Expand Down Expand Up @@ -431,7 +431,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
]
if len(self.mlflow_loggers) == 0:
raise ValueError(
f'Got {self.mlflow_registered_model_name=} and {self.save_folder}, but no `MLFlowLogger` was found in the `logger.destinations` list. '
f'Got {self.mlflow_registered_model_name=} and {self.save_folder=}, but no `MLFlowLogger` was found in the `logger.destinations` list. '
+
'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None` and set `save_folder`',
)
Expand Down
113 changes: 59 additions & 54 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def test_final_register_only(

@pytest.mark.gpu
@pytest.mark.parametrize('log_to_mlflow', [True, False])
@pytest.mark.parametrize('hf_save_folder', [True, False])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)],
Expand All @@ -495,6 +496,7 @@ def test_final_register_only(
def test_huggingface_conversion_callback_interval(
tmp_path: pathlib.Path,
log_to_mlflow: bool,
hf_save_folder: bool,
hf_save_interval: str,
save_interval: str,
max_duration: str,
Expand All @@ -515,7 +517,7 @@ def test_huggingface_conversion_callback_interval(
batches_per_epoch = math.ceil(dataset_size / device_batch_size)

checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_folder=os.path.join(tmp_path, 'checkpoints') if hf_save_folder else None,
save_interval=hf_save_interval,
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name'
Expand All @@ -542,19 +544,25 @@ def test_huggingface_conversion_callback_interval(
save_interval=save_interval,
max_duration=max_duration,
callbacks=[checkpointer_callback],
loggers=[mlflow_logger_mock] if log_to_mlflow else [],
loggers=[mlflow_logger_mock] if log_to_mlflow or not hf_save_folder else [],
optimizers=optimizer,
save_latest_filename=None,
)
trainer.fit()

if log_to_mlflow:
assert mlflow_logger_mock.log_model.call_count == 1
if hf_save_folder and not log_to_mlflow:
assert checkpointer_callback.transform_model_pre_registration.call_count == 0
assert checkpointer_callback.pre_register_edit.call_count == 0
assert mlflow_logger_mock.log_model.call_count == 0
else:
expected_call_count = 1 if hf_save_folder else expected_hf_checkpoints
expected_registered_model_name = 'dummy-registered-name' if log_to_mlflow else None
assert mlflow_logger_mock.log_model.call_count == expected_call_count
mlflow_logger_mock.log_model.assert_called_with(
transformers_model=ANY,
flavor='transformers',
artifact_path=f'huggingface/ba{trainer.state.timestamp.batch.value}',
registered_model_name='dummy-registered-name',
registered_model_name=expected_registered_model_name,
run_id='mlflow-run-id',
await_registration_for=3600,
metadata=ANY,
Expand All @@ -563,64 +571,61 @@ def test_huggingface_conversion_callback_interval(
'prompt': np.array(['What is Machine Learning?']),
},
)
assert checkpointer_callback.transform_model_pre_registration.call_count == 1
assert checkpointer_callback.pre_register_edit.call_count == 1
assert mlflow_logger_mock.log_model.call_count == 1
else:
assert checkpointer_callback.transform_model_pre_registration.call_count == 0
assert checkpointer_callback.pre_register_edit.call_count == 0
assert mlflow_logger_mock.log_model.call_count == 0
assert checkpointer_callback.transform_model_pre_registration.call_count == expected_call_count
assert checkpointer_callback.pre_register_edit.call_count == expected_call_count
assert mlflow_logger_mock.log_model.call_count == expected_call_count

normal_checkpoints = [
name for name in os.listdir(os.path.join(tmp_path, 'checkpoints'))
if name != 'huggingface'
]

huggingface_checkpoints = list(
os.listdir(os.path.join(tmp_path, 'checkpoints', 'huggingface')),
)
assert len(normal_checkpoints) == expected_normal_checkpoints
assert len(huggingface_checkpoints) == expected_hf_checkpoints

# Load the last huggingface checkpoint
loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
os.path.join(
tmp_path,
'checkpoints',
'huggingface',
f'ba{batches_per_epoch}',
),
trust_remote_code=True,
)
if hf_save_folder:
huggingface_checkpoints = list(
os.listdir(os.path.join(tmp_path, 'checkpoints', 'huggingface')),
)
assert len(huggingface_checkpoints) == expected_hf_checkpoints

# Check that the loaded model has the correct precision, and then set it back
# to the original for the equivalence check
assert loaded_model.config.torch_dtype == precision
loaded_model.config.torch_dtype = original_model.model.config.torch_dtype

# Check that we have correctly set these attributes, and then set them back
# to the original for the equivalence check
assert loaded_model.config.attn_config['attn_impl'] == 'torch'
assert loaded_model.config.init_device == 'cpu'
loaded_model.config.attn_config[
'attn_impl'] = original_model.model.config.attn_config['attn_impl']
loaded_model.config.init_device = original_model.model.config.init_device

loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
os.path.join(
tmp_path,
'checkpoints',
'huggingface',
f'ba{batches_per_epoch}',
),
trust_remote_code=True,
)
# Load the last huggingface checkpoint
loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
os.path.join(
tmp_path,
'checkpoints',
'huggingface',
f'ba{batches_per_epoch}',
),
trust_remote_code=True,
)

check_hf_model_equivalence(
trainer.state.model.model.to(precision),
loaded_model,
)
check_hf_tokenizer_equivalence(mpt_tokenizer, loaded_tokenizer)
# Check that the loaded model has the correct precision, and then set it back
# to the original for the equivalence check
assert loaded_model.config.torch_dtype == precision
loaded_model.config.torch_dtype = original_model.model.config.torch_dtype

# Check that we have correctly set these attributes, and then set them back
# to the original for the equivalence check
assert loaded_model.config.attn_config['attn_impl'] == 'torch'
assert loaded_model.config.init_device == 'cpu'
loaded_model.config.attn_config[
'attn_impl'] = original_model.model.config.attn_config['attn_impl']
loaded_model.config.init_device = original_model.model.config.init_device

loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
os.path.join(
tmp_path,
'checkpoints',
'huggingface',
f'ba{batches_per_epoch}',
),
trust_remote_code=True,
)

check_hf_model_equivalence(
trainer.state.model.model.to(precision),
loaded_model,
)
check_hf_tokenizer_equivalence(mpt_tokenizer, loaded_tokenizer)

delete_transformers_cache()

Expand Down

0 comments on commit cdd3f7e

Please sign in to comment.