From 1fb5eec62b6dd8ebe53f13a1a7dd4c14a209cc20 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 05:13:12 +0200 Subject: [PATCH 01/16] pre register transform --- llmfoundry/callbacks/hf_checkpointer.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4de7f9f2c6..4a9518d1e6 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -362,6 +362,23 @@ def transform_config( copied_config.ffn_config['moe_world_size'] = 1 return copied_config + def transform_model_pre_registration( + self, + model: PreTrainedModel, + ) -> PreTrainedModel: + """Transform the model before registering with MLFlow. + + This allows a subclass to modify the model before registering with MLFlow. The base class implementation will + make no modifications. + + Args: + model (PreTrainedModel): The model to be transformed. + + Returns: + PreTrainedModel: The transformed model. + """ + return model + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -556,6 +573,11 @@ def dtensor_to_tensor_hook( if dist.get_global_rank() == 0: if self.mlflow_registered_model_name and self._is_last_batch(state): + + new_model_instance = self.transform_model_pre_registration( + new_model_instance + ) + components = {'model': new_model_instance} if original_tokenizer is not None: components['tokenizer'] = original_tokenizer From 1cc7ca7e4a3b36ea9ba1fb785b069a63693ef7d5 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 06:21:34 +0200 Subject: [PATCH 02/16] meta --- llmfoundry/callbacks/hf_checkpointer.py | 37 ++++++++++++------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4a9518d1e6..67f839661c 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -497,26 +497,25 @@ def dtensor_to_tensor_hook( log.debug(f'Creating new model instance') - if composer_model.using_peft: - # We don't use meta here because the state dict does not contain the full - # model, only the adapter weights. - active_adapter = original_model.active_adapter - base_model = original_model.get_base_model() - new_base_model_instance = type(base_model)(new_config) - - new_model_instance = type(original_model)( - new_base_model_instance, - original_model.peft_config[active_adapter], - ) - new_model_instance.to(dtype=self.dtype) - else: - # First create the model instance on meta device to avoid the - # initialization cost. - with init_empty_weights(): - new_model_instance = type(original_model)(new_config) - new_model_instance.generation_config.update( - **original_model.generation_config.to_dict(), + # First create the model instance on meta device to avoid the + # initialization cost. 
+ with init_empty_weights(): + if composer_model.using_peft: + active_adapter = original_model.active_adapter + base_model = original_model.get_base_model() + new_base_model_instance = type(base_model)(new_config) + + new_model_instance = type(original_model)( + new_base_model_instance, + original_model.peft_config[active_adapter], ) + new_model_instance.to(dtype=self.dtype) + else: + with init_empty_weights(): + new_model_instance = type(original_model)(new_config) + new_model_instance.generation_config.update( + **original_model.generation_config.to_dict(), + ) # Then load the state dict in with "assign" so that the state dict # is loaded properly even though the model is initially on meta device. From 432ef34aa9a59b5517ab90118e6ea7113e1b77b2 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 07:25:36 +0200 Subject: [PATCH 03/16] yo --- llmfoundry/callbacks/hf_checkpointer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 67f839661c..b50b29e425 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -365,6 +365,7 @@ def transform_config( def transform_model_pre_registration( self, model: PreTrainedModel, + composer_model, ) -> PreTrainedModel: """Transform the model before registering with MLFlow. @@ -373,6 +374,7 @@ def transform_model_pre_registration( Args: model (PreTrainedModel): The model to be transformed. + composer_model: The composer model. Returns: PreTrainedModel: The transformed model. From d305657ea2369c2247ed7de6b4a868ab070ff244 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 07:35:54 +0200 Subject: [PATCH 04/16] yo --- llmfoundry/callbacks/hf_checkpointer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index b50b29e425..94f79b727c 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -576,7 +576,8 @@ def dtensor_to_tensor_hook( if self.mlflow_registered_model_name and self._is_last_batch(state): new_model_instance = self.transform_model_pre_registration( - new_model_instance + new_model_instance, + composer_model, ) components = {'model': new_model_instance} From ce2679b78320d6564693e50a48cfe4085f21597a Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 08:32:52 +0200 Subject: [PATCH 05/16] yo --- llmfoundry/callbacks/hf_checkpointer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 94f79b727c..f698c93126 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -425,6 +425,8 @@ def _save_checkpoint(self, state: State, logger: Logger): state_dict_model = state.model.model original_tokenizer = state.model.tokenizer + self.using_peft = composer_model.using_peft + if version.parse(torch.__version__) > version.parse('2.2.9'): from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.state_dict import ( @@ -502,7 +504,7 @@ def dtensor_to_tensor_hook( # First create the model instance on meta device to avoid the # initialization cost. 
with init_empty_weights(): - if composer_model.using_peft: + if self.using_peft: active_adapter = original_model.active_adapter base_model = original_model.get_base_model() new_base_model_instance = type(base_model)(new_config) @@ -599,7 +601,7 @@ def dtensor_to_tensor_hook( model_saving_kwargs: Dict[str, Any] = { 'path': local_save_path, } - if composer_model.using_peft: + if self.using_peft: model_saving_kwargs['flavor'] = 'peft' model_saving_kwargs['save_pretrained_dir' ] = temp_save_dir From aab92b319096cd22b2ba3d6992f2eda092ef3625 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 09:20:28 +0200 Subject: [PATCH 06/16] yo --- llmfoundry/callbacks/hf_checkpointer.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index f698c93126..53836c3d9a 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -365,7 +365,6 @@ def transform_config( def transform_model_pre_registration( self, model: PreTrainedModel, - composer_model, ) -> PreTrainedModel: """Transform the model before registering with MLFlow. @@ -374,7 +373,6 @@ def transform_model_pre_registration( Args: model (PreTrainedModel): The model to be transformed. - composer_model: The composer model. Returns: PreTrainedModel: The transformed model. @@ -484,7 +482,7 @@ def dtensor_to_tensor_hook( with state_dict_context: state_dict = state_dict_model.state_dict() - # Convert the state dict to the requested precis + # Convert the state dict to the requested precision for k, v in state_dict.items(): if isinstance(v, torch.Tensor): state_dict[k] = v.to(dtype=self.dtype) @@ -515,11 +513,10 @@ def dtensor_to_tensor_hook( ) new_model_instance.to(dtype=self.dtype) else: - with init_empty_weights(): - new_model_instance = type(original_model)(new_config) - new_model_instance.generation_config.update( - **original_model.generation_config.to_dict(), - ) + new_model_instance = type(original_model)(new_config) + new_model_instance.generation_config.update( + **original_model.generation_config.to_dict(), + ) # Then load the state dict in with "assign" so that the state dict # is loaded properly even though the model is initially on meta device. @@ -579,7 +576,6 @@ def dtensor_to_tensor_hook( new_model_instance = self.transform_model_pre_registration( new_model_instance, - composer_model, ) components = {'model': new_model_instance} From 04dbf16f5f99c4414c3a5c2da8003720973bbb32 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 09:31:34 +0200 Subject: [PATCH 07/16] log --- llmfoundry/callbacks/hf_checkpointer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 53836c3d9a..457af1cbfa 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -521,6 +521,8 @@ def dtensor_to_tensor_hook( # Then load the state dict in with "assign" so that the state dict # is loaded properly even though the model is initially on meta device. 
new_model_instance.load_state_dict(state_dict, assign=True) + print(state_dict.keys()) + exit(0) del state_dict # Transform the model and tokenizer before saving From 8cd4f5f47575b5cb52235329b19a171b7f8333ce Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 10:05:36 +0200 Subject: [PATCH 08/16] log --- llmfoundry/callbacks/hf_checkpointer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 457af1cbfa..d13a3290c6 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -481,6 +481,9 @@ def dtensor_to_tensor_hook( FSDP)) else contextlib.nullcontext() with state_dict_context: state_dict = state_dict_model.state_dict() + + print("State dict model type:", type(state_dict_model)) + print("is it a hugging face model?", isinstance(state_dict_model, HuggingFaceModel)) # Convert the state dict to the requested precision for k, v in state_dict.items(): From 2e979e43928c530b6da09ad2ca3c2e5803e2b6fc Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 10:15:56 +0200 Subject: [PATCH 09/16] yo --- llmfoundry/callbacks/hf_checkpointer.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index d13a3290c6..5026aaf5b3 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -481,9 +481,6 @@ def dtensor_to_tensor_hook( FSDP)) else contextlib.nullcontext() with state_dict_context: state_dict = state_dict_model.state_dict() - - print("State dict model type:", type(state_dict_model)) - print("is it a hugging face model?", isinstance(state_dict_model, HuggingFaceModel)) # Convert the state dict to the requested precision for k, v in state_dict.items(): @@ -524,8 +521,6 @@ def dtensor_to_tensor_hook( # Then load the state dict in with "assign" so that the state dict # is loaded properly even though the model is initially on meta device. 
new_model_instance.load_state_dict(state_dict, assign=True) - print(state_dict.keys()) - exit(0) del state_dict # Transform the model and tokenizer before saving @@ -579,9 +574,7 @@ def dtensor_to_tensor_hook( if dist.get_global_rank() == 0: if self.mlflow_registered_model_name and self._is_last_batch(state): - new_model_instance = self.transform_model_pre_registration( - new_model_instance, - ) + new_model_instance = self.transform_model_pre_registration(new_model_instance) components = {'model': new_model_instance} if original_tokenizer is not None: From acd1f88070ac4100afb44453dcb52fc7fc855312 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 10:21:23 +0200 Subject: [PATCH 10/16] ay --- llmfoundry/callbacks/hf_checkpointer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 5026aaf5b3..73eee640b7 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -574,7 +574,9 @@ def dtensor_to_tensor_hook( if dist.get_global_rank() == 0: if self.mlflow_registered_model_name and self._is_last_batch(state): - new_model_instance = self.transform_model_pre_registration(new_model_instance) + new_model_instance = self.transform_model_pre_registration( + new_model_instance + ) components = {'model': new_model_instance} if original_tokenizer is not None: From 3955df652a1c9db12379167160ebe2be8b43c1d2 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 10:36:38 +0200 Subject: [PATCH 11/16] yo --- llmfoundry/callbacks/hf_checkpointer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 73eee640b7..a90f74a591 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -179,6 +179,7 @@ def __init__( 'bfloat16': torch.bfloat16, }[precision] self.flatten_imports = flatten_imports + self.using_peft = False # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name From 49c5ddd54c991d5bc03aa720819eaef4b3cf84b8 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 19 Jul 2024 10:37:48 +0200 Subject: [PATCH 12/16] nice --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index a90f74a591..1ebcdcf42d 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -576,7 +576,7 @@ def dtensor_to_tensor_hook( if self.mlflow_registered_model_name and self._is_last_batch(state): new_model_instance = self.transform_model_pre_registration( - new_model_instance + new_model_instance, ) components = {'model': new_model_instance} From f694ea7d48c61e5b911481dc82c955e735a30315 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Sat, 20 Jul 2024 05:19:15 +0200 Subject: [PATCH 13/16] test --- llmfoundry/callbacks/hf_checkpointer.py | 17 ++++++++++------- .../inference/test_convert_composer_to_hf.py | 3 +++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 1ebcdcf42d..82ca927f74 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -17,6 +17,7 @@ import numpy as np import torch import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from composer.core import 
Callback, Event, Precision, State, Time, TimeUnit from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger @@ -275,6 +276,15 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( '1GB', ) + + # Check if the model is using PEFT + if state.is_model_ddp: + composer_model = state.model.module + elif isinstance(state.model.model, FSDP): + composer_model = state.model + else: + composer_model = state.model + self.using_peft = composer_model.using_peft elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. timeout = 3600 @@ -406,26 +416,20 @@ def _save_checkpoint(self, state: State, logger: Logger): temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if state.is_model_ddp: - composer_model = state.model.module original_model: PreTrainedModel = state.model.module.model state_dict_model = state.model.module.model original_tokenizer = state.model.module.tokenizer elif isinstance(state.model.model, FSDP): - composer_model = state.model original_model: PreTrainedModel = state.model.model.module state_dict_model = state.model.model original_tokenizer = state.model.tokenizer else: - composer_model = state.model original_model: PreTrainedModel = state.model.model state_dict_model = state.model.model original_tokenizer = state.model.tokenizer - self.using_peft = composer_model.using_peft - if version.parse(torch.__version__) > version.parse('2.2.9'): from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.state_dict import ( @@ -512,7 +516,6 @@ def dtensor_to_tensor_hook( new_base_model_instance, original_model.peft_config[active_adapter], ) - new_model_instance.to(dtype=self.dtype) else: new_model_instance = type(original_model)(new_config) new_model_instance.generation_config.update( diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 2ef458fece..7da063fdcd 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -383,6 +383,7 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' + transform_model_pre_reg_mock = MagicMock(wraps = checkpointer_callback.transform_model_pre_registration) trainer = Trainer( model=original_model, device='gpu', @@ -407,8 +408,10 @@ def test_huggingface_conversion_callback_interval( input_example=ANY, metadata={}, ) + assert transform_model_pre_reg_mock.call_count == 1 assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: + assert transform_model_pre_reg_mock.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 From c870af76e46a28a143385e74a73192d2c11ec82a Mon Sep 17 00:00:00 2001 From: Saaketh Date: Sat, 20 Jul 2024 06:08:21 +0200 Subject: [PATCH 14/16] test --- llmfoundry/callbacks/hf_checkpointer.py | 105 ++++++++---------- .../inference/test_convert_composer_to_hf.py | 8 +- 2 files changed, 51 insertions(+), 62 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 
82ca927f74..d53d159090 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -17,9 +17,7 @@ import numpy as np import torch import torch.nn as nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from composer.core import Callback, Event, Precision, State, Time, TimeUnit -from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger from composer.models import HuggingFaceModel from composer.utils import ( @@ -30,7 +28,12 @@ ) from composer.utils.misc import create_interval_scheduler from mlflow.transformers import _fetch_model_card, _write_license_information -from packaging import version +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import ( PretrainedConfig, PreTrainedModel, @@ -276,7 +279,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( '1GB', ) - + # Check if the model is using PEFT if state.is_model_ddp: composer_model = state.model.module @@ -430,62 +433,46 @@ def _save_checkpoint(self, state: State, logger: Logger): state_dict_model = state.model.model original_tokenizer = state.model.tokenizer - if version.parse(torch.__version__) > version.parse('2.2.9'): - from torch.distributed._tensor import DTensor - from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - ) - cpu_offload = True - - # Add a dtensor->cpu tensor hook to avoid CUDA OOM - def dtensor_to_tensor_hook( - module: nn.Module, - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> Dict[str, Any]: - dtensor_fqns = [] - for fqn in state_dict.keys(): - tensor = state_dict[fqn] - if isinstance(tensor, DTensor): - dtensor_fqns.append(fqn) - tensor = tensor.full_tensor() # type: ignore - if dist.get_global_rank() == 0: - if cpu_offload: - tensor = tensor.cpu() - state_dict[fqn] = tensor - if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] - return state_dict - - hooks = [] - for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append( - module. 
- _register_state_dict_hook(dtensor_to_tensor_hook), - ) + cpu_offload = True + + # Add a dtensor->cpu tensor hook to avoid CUDA OOM + def dtensor_to_tensor_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + dtensor_fqns = [] + for fqn in state_dict.keys(): + tensor = state_dict[fqn] + if isinstance(tensor, DTensor): + dtensor_fqns.append(fqn) + tensor = tensor.full_tensor() # type: ignore + if dist.get_global_rank() == 0: + if cpu_offload: + tensor = tensor.cpu() + state_dict[fqn] = tensor + if dist.get_global_rank() != 0: + for fqn in dtensor_fqns: + del state_dict[fqn] + return state_dict + + hooks = [] + for _, module in state_dict_model.named_modules(): + if isinstance(module, FSDP): + hooks.append( + module._register_state_dict_hook(dtensor_to_tensor_hook), + ) - state_dict = get_model_state_dict( - state_dict_model, - options=StateDictOptions( - full_state_dict=True, - cpu_offload=cpu_offload, - ), - ) - for hook in hooks: - hook.remove() - else: - state_dict_context = fsdp_state_dict_type_context( - original_model, - state_dict_type='full', - ) if ((not state.is_model_ddp) and - isinstance(state_dict_model, - FSDP)) else contextlib.nullcontext() - with state_dict_context: - state_dict = state_dict_model.state_dict() + state_dict = get_model_state_dict( + state_dict_model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=cpu_offload, + ), + ) + for hook in hooks: + hook.remove() # Convert the state dict to the requested precision for k, v in state_dict.items(): diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 7da063fdcd..68dc855154 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -383,7 +383,9 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' - transform_model_pre_reg_mock = MagicMock(wraps = checkpointer_callback.transform_model_pre_registration) + checkpointer_callback.transform_model_pre_registration = MagicMock( + wraps=checkpointer_callback.transform_model_pre_registration, + ) trainer = Trainer( model=original_model, device='gpu', @@ -408,10 +410,10 @@ def test_huggingface_conversion_callback_interval( input_example=ANY, metadata={}, ) - assert transform_model_pre_reg_mock.call_count == 1 + assert checkpointer_callback.transform_model_pre_registration.call_count == 1 assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: - assert transform_model_pre_reg_mock.call_count == 0 + assert checkpointer_callback.transform_model_pre_registration.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 From 500c500ab260e404599caabe8913d21ee004b33c Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Fri, 19 Jul 2024 21:52:03 -0700 Subject: [PATCH 15/16] Update llmfoundry/callbacks/hf_checkpointer.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index d53d159090..3b1733b8aa 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ 
-382,7 +382,7 @@ def transform_model_pre_registration( ) -> PreTrainedModel: """Transform the model before registering with MLFlow. - This allows a subclass to modify the model before registering with MLFlow. The base class implementation will + This allows a subclass to modify the model before registering with MLflow. The base class implementation will make no modifications. Args: From 212844b339d3775cbff4b51158147778e6463e31 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Fri, 19 Jul 2024 21:52:08 -0700 Subject: [PATCH 16/16] Update llmfoundry/callbacks/hf_checkpointer.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 3b1733b8aa..2ade458bb4 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -380,7 +380,7 @@ def transform_model_pre_registration( self, model: PreTrainedModel, ) -> PreTrainedModel: - """Transform the model before registering with MLFlow. + """Transform the model before registering with MLflow. This allows a subclass to modify the model before registering with MLflow. The base class implementation will make no modifications.
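
The central addition in this series is the `transform_model_pre_registration` hook, invoked on rank 0 just before the model is registered with MLflow. A minimal sketch of how a subclass might use it follows; the `RegisteredModelCheckpointer` name and the `use_cache` tweak are illustrative assumptions, not part of the patch, while the hook name and signature come from the diff above.

```python
# Illustrative sketch only: RegisteredModelCheckpointer and the use_cache tweak
# are assumptions for demonstration; the hook itself is defined in the patch.
from transformers import PreTrainedModel

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class RegisteredModelCheckpointer(HuggingFaceCheckpointer):
    """Checkpointer that edits only the copy registered with MLflow."""

    def transform_model_pre_registration(
        self,
        model: PreTrainedModel,
    ) -> PreTrainedModel:
        # Runs after the Hugging Face checkpoint has already been written to
        # disk, so changes here affect only the MLflow-registered model.
        model.config.use_cache = False
        return model
```

Because the hook is only called when `mlflow_registered_model_name` is set and the last batch has been reached, the on-disk Hugging Face checkpoint produced earlier in `_save_checkpoint` is unaffected.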
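
The other recurring change is reconstructing the new model instance on the meta device and loading the gathered state dict with `assign=True`. A standalone sketch of that pattern is below; `GPT2Config`/`GPT2LMHeadModel` are stand-ins for illustration, not the model used in the callback, which builds the gathered full state dict via `get_model_state_dict` instead of a fresh CPU model.

```python
from accelerate import init_empty_weights
from transformers import GPT2Config, GPT2LMHeadModel

# Small stand-in config; the real callback uses the transformed config of the
# model being trained.
config = GPT2Config(n_layer=2, n_head=2, n_embd=64)

with init_empty_weights():
    # Parameters are created on the meta device: no allocation, no init cost.
    empty_model = GPT2LMHeadModel(config)

# Stand-in for the full state dict gathered in the callback; here a freshly
# initialized CPU model provides the weights.
state_dict = GPT2LMHeadModel(config).state_dict()

# assign=True re-binds each parameter to the loaded tensor, so the meta-device
# model becomes a real model without a second initialization pass.
empty_model.load_state_dict(state_dict, assign=True)
assert empty_model.transformer.wte.weight.device.type == 'cpu'
```

Since the state dict is already cast to the requested precision before loading, `assign=True` also carries the dtype across, which is presumably why the explicit `.to(dtype=self.dtype)` cast in the PEFT branch could be dropped in PATCH 13.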