diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index c0db8e1c28..f9618c5fa2 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -95,7 +95,9 @@ class HuggingFaceCheckpointer(Callback): that will get passed along to the MLflow ``save_model`` call. Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and - ``{'task': 'llm/v1/completions'}`` respectively. + ``{'task': 'llm/v1/completions'}`` respectively. A default input example + and signature intended for text generation is also included under the + keys ``input_example`` and ``signature``. flatten_imports (Sequence[str]): A sequence of import prefixes that will be flattened when editing MPT files. """ @@ -126,6 +128,10 @@ def __init__( if mlflow_logging_config is None: mlflow_logging_config = {} if self.mlflow_registered_model_name is not None: + import numpy as np + from mlflow.models.signature import ModelSignature + from mlflow.types.schema import ColSpec, Schema + # Both the metadata and the task are needed in order for mlflow # and databricks optimized model serving to work default_metadata = {'task': 'llm/v1/completions'} @@ -135,6 +141,28 @@ def __init__( **passed_metadata } mlflow_logging_config.setdefault('task', 'text-generation') + + # Define a default input/output that is good for standard text generation LMs + input_schema = Schema([ + ColSpec('string', 'prompt'), + ColSpec('double', 'temperature', optional=True), + ColSpec('integer', 'max_tokens', optional=True), + ColSpec('string', 'stop', optional=True), + ColSpec('integer', 'candidate_count', optional=True) + ]) + + output_schema = Schema([ColSpec('string', 'predictions')]) + + default_signature = ModelSignature(inputs=input_schema, + outputs=output_schema) + + default_input_example = { + 'prompt': np.array(['What is Machine Learning?']) + } + mlflow_logging_config.setdefault('input_example', + default_input_example) + mlflow_logging_config.setdefault('signature', default_signature) + self.mlflow_logging_config = mlflow_logging_config self.huggingface_folder_name_fstr = os.path.join( diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index ba99cd43b3..ab2d569132 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -253,12 +253,12 @@ def test_callback_inits(): save_folder='test', save_interval='1ba', mlflow_registered_model_name='test_model_name') - assert hf_checkpointer.mlflow_logging_config == { - 'task': 'text-generation', - 'metadata': { - 'task': 'llm/v1/completions' - } - } + + assert hf_checkpointer.mlflow_logging_config['task'] == 'text-generation' + assert hf_checkpointer.mlflow_logging_config['metadata'][ + 'task'] == 'llm/v1/completions' + assert 'input_example' in hf_checkpointer.mlflow_logging_config + assert 'signature' in hf_checkpointer.mlflow_logging_config @pytest.mark.gpu @@ -331,6 +331,8 @@ def test_huggingface_conversion_callback_interval( transformers_model=ANY, path=ANY, task='text-generation', + input_example=ANY, + signature=ANY, metadata={'task': 'llm/v1/completions'}) assert mlflow_logger_mock.register_model.call_count == 1 else: @@ -593,11 +595,34 @@ def test_huggingface_conversion_callback( } } else: + import numpy as np + from mlflow.models.signature import ModelSignature + from mlflow.types.schema import ColSpec, Schema + + input_schema = Schema([ + ColSpec('string', 'prompt'), + ColSpec('double', 'temperature', optional=True), + ColSpec('integer', 'max_tokens', optional=True), + ColSpec('string', 'stop', optional=True), + ColSpec('integer', 'candidate_count', optional=True) + ]) + + output_schema = Schema([ColSpec('string', 'predictions')]) + + default_signature = ModelSignature(inputs=input_schema, + outputs=output_schema) + + default_input_example = { + 'prompt': np.array(['What is Machine Learning?']) + } + expectation = { 'flavor': 'transformers', 'transformers_model': ANY, 'path': ANY, 'task': 'text-generation', + 'signature': default_signature, + 'input_example': default_input_example, 'metadata': { 'task': 'llm/v1/completions' }