diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 88d8022508..09fafe3ee5 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -268,176 +268,166 @@ def _save_checkpoint(self, state: State, logger: Logger): Path(self.save_dir_format_str) / self.huggingface_folder_name_fstr), state.run_name, state.timestamp) - dir_context_mgr = tempfile.TemporaryDirectory( - ) if self.remote_ud is not None else contextlib.nullcontext( - enter_result=save_dir) - - with dir_context_mgr as temp_save_dir: - assert isinstance(temp_save_dir, - str) # pyright doesn't know about enter_result - - log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - if state.is_model_ddp: - composer_model = state.model.module - original_model: PreTrainedModel = state.model.module.model - state_dict_model = state.model.module.model - original_tokenizer = state.model.module.tokenizer - elif isinstance(state.model.model, FSDP): - composer_model = state.model - original_model: PreTrainedModel = state.model.model.module - state_dict_model = state.model.model - original_tokenizer = state.model.tokenizer + + temp_save_dir = tempfile.mkdtemp( + ) if self.remote_ud is not None else save_dir + + log.debug('Gathering state dict') + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + if state.is_model_ddp: + composer_model = state.model.module + original_model: PreTrainedModel = state.model.module.model + state_dict_model = state.model.module.model + original_tokenizer = state.model.module.tokenizer + elif isinstance(state.model.model, FSDP): + composer_model = state.model + original_model: PreTrainedModel = state.model.model.module + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + else: + composer_model = state.model + original_model: PreTrainedModel = state.model.model + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + + state_dict_context = fsdp_state_dict_type_context( + original_model, + state_dict_type='full') if ((not state.is_model_ddp) and isinstance( + state_dict_model, FSDP)) else contextlib.nullcontext() + + with state_dict_context: + state_dict = state_dict_model.state_dict() + + # convert the state dict to the requested precision + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + state_dict[k] = v.to(dtype=self.dtype) + + new_model_instance = None # Need this for pyright because variable could be unbound + + if dist.get_global_rank() == 0: + log.debug('Saving Hugging Face checkpoint in global rank 0') + + copied_config = copy.deepcopy(original_model.config) + if copied_config.model_type == 'mpt': + copied_config.attn_config['attn_impl'] = 'torch' + copied_config.init_device = 'cpu' + + log.debug(f'Creating new model instance') + + if composer_model.using_peft: + # We don't use meta here because the state dict does not contain the full + # model, only the adapter weights. + active_adapter = original_model.active_adapter + base_model = original_model.get_base_model() + new_base_model_instance = type(base_model)(copied_config) + + new_model_instance = type(original_model)( + new_base_model_instance, + original_model.peft_config[active_adapter]) + new_model_instance.to(dtype=self.dtype) else: - composer_model = state.model - original_model: PreTrainedModel = state.model.model - state_dict_model = state.model.model - original_tokenizer = state.model.tokenizer - - state_dict_context = fsdp_state_dict_type_context( - original_model, state_dict_type='full') if ( - (not state.is_model_ddp) and isinstance( - state_dict_model, FSDP)) else contextlib.nullcontext() - - with state_dict_context: - state_dict = state_dict_model.state_dict() - - # convert the state dict to the requested precision - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) - - new_model_instance = None # Need this for pyright because variable could be unbound - - if dist.get_global_rank() == 0: - log.debug('Saving Hugging Face checkpoint in global rank 0') - - copied_config = copy.deepcopy(original_model.config) - if copied_config.model_type == 'mpt': - copied_config.attn_config['attn_impl'] = 'torch' - copied_config.init_device = 'cpu' - - log.debug(f'Creating new model instance') - - if composer_model.using_peft: - # We don't use meta here because the state dict does not contain the full - # model, only the adapter weights. - active_adapter = original_model.active_adapter - base_model = original_model.get_base_model() - new_base_model_instance = type(base_model)(copied_config) - - new_model_instance = type(original_model)( - new_base_model_instance, - original_model.peft_config[active_adapter]) - new_model_instance.to(dtype=self.dtype) - else: - # First create the model instance on meta device to avoid the - # initialization cost. - with init_empty_weights(): - new_model_instance = type(original_model)(copied_config) - - # Then load the state dict in with "assign" so that the state dict - # is loaded properly even though the model is initially on meta device. - new_model_instance.load_state_dict(state_dict, assign=True) - del state_dict - - log.debug('Saving Hugging Face checkpoint to disk') - new_model_instance.save_pretrained(temp_save_dir) - if original_tokenizer is not None: - assert isinstance(original_tokenizer, - PreTrainedTokenizerBase) - original_tokenizer.save_pretrained(temp_save_dir) - - # Only need to edit files for MPT because it has custom code - if original_model.config.model_type == 'mpt': - log.debug('Editing MPT files for HuggingFace compatibility') - edit_files_for_hf_compatibility( - temp_save_dir, - self.flatten_imports, + # First create the model instance on meta device to avoid the + # initialization cost. + with init_empty_weights(): + new_model_instance = type(original_model)(copied_config) + + # Then load the state dict in with "assign" so that the state dict + # is loaded properly even though the model is initially on meta device. + new_model_instance.load_state_dict(state_dict, assign=True) + del state_dict + + log.debug('Saving Hugging Face checkpoint to disk') + new_model_instance.save_pretrained(temp_save_dir) + if original_tokenizer is not None: + assert isinstance(original_tokenizer, PreTrainedTokenizerBase) + original_tokenizer.save_pretrained(temp_save_dir) + + # Only need to edit files for MPT because it has custom code + if original_model.config.model_type == 'mpt': + log.debug('Editing MPT files for HuggingFace compatibility') + edit_files_for_hf_compatibility( + temp_save_dir, + self.flatten_imports, + ) + + if self.remote_ud is not None: + for filename in os.listdir(temp_save_dir): + remote_file_name = os.path.join(save_dir, filename) + remote_file_uri = self.remote_ud.remote_backend.get_uri( + remote_file_name) + log.info( + f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' + ) + self.remote_ud.upload_file( + state=state, + remote_file_name=remote_file_name, + file_path=Path(os.path.join(temp_save_dir, filename)), + overwrite=self.overwrite, ) - if self.remote_ud is not None: - for filename in os.listdir(temp_save_dir): - remote_file_name = os.path.join(save_dir, filename) - remote_file_uri = self.remote_ud.remote_backend.get_uri( - remote_file_name) - log.info( - f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' - ) - self.remote_ud.upload_file( - state=state, - remote_file_name=remote_file_name, - file_path=Path(os.path.join(temp_save_dir, - filename)), - overwrite=self.overwrite, - ) + dist.barrier() - dist.barrier() + if dist.get_global_rank() == 0: + if self.mlflow_registered_model_name and self._is_last_batch(state): + components = {'model': new_model_instance} + if original_tokenizer is not None: + components['tokenizer'] = original_tokenizer - if dist.get_global_rank() == 0: - if self.mlflow_registered_model_name and self._is_last_batch( - state): - components = {'model': new_model_instance} - if original_tokenizer is not None: - components['tokenizer'] = original_tokenizer - - log.debug('Logging Hugging Face model to MLFlow') - for i, mlflow_logger in enumerate(self.mlflow_loggers): - log.debug( - f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' + log.debug('Logging Hugging Face model to MLFlow') + for i, mlflow_logger in enumerate(self.mlflow_loggers): + log.debug( + f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' + ) + local_save_path = str( + Path(temp_save_dir) / f'mlflow_save_{i}') + + # TODO: Remove after mlflow fixes the bug that makes this necessary + import mlflow + mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' + model_saving_kwargs: Dict[str, Any] = { + 'path': local_save_path + } + if composer_model.using_peft: + model_saving_kwargs['flavor'] = 'peft' + model_saving_kwargs[ + 'save_pretrained_dir'] = temp_save_dir + model_saving_kwargs[ + 'metadata'] = self.mlflow_logging_config['metadata'] + else: + model_saving_kwargs['flavor'] = 'transformers' + model_saving_kwargs['transformers_model'] = components + model_saving_kwargs.update(self.mlflow_logging_config) + + mlflow_logger.save_model(**model_saving_kwargs) + + # Upload the license file generated by mlflow during the model saving. + license_filename = _maybe_get_license_filename( + local_save_path, + self.mlflow_logging_config['metadata'].get( + 'pretrained_model_name', None)) + if license_filename is not None: + mlflow_logger._mlflow_client.log_artifact( + mlflow_logger._run_id, + os.path.join(local_save_path, license_filename), ) - local_save_path = str( - Path(temp_save_dir) / f'mlflow_save_{i}') - - # TODO: Remove after mlflow fixes the bug that makes this necessary - import mlflow - mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' - model_saving_kwargs: Dict[str, Any] = { - 'path': local_save_path - } - if composer_model.using_peft: - model_saving_kwargs['flavor'] = 'peft' - model_saving_kwargs[ - 'save_pretrained_dir'] = temp_save_dir - model_saving_kwargs[ - 'metadata'] = self.mlflow_logging_config[ - 'metadata'] - else: - model_saving_kwargs['flavor'] = 'transformers' - model_saving_kwargs[ - 'transformers_model'] = components - model_saving_kwargs.update( - self.mlflow_logging_config) - - mlflow_logger.save_model(**model_saving_kwargs) - - # Upload the license file generated by mlflow during the model saving. - license_filename = _maybe_get_license_filename( - local_save_path, - self.mlflow_logging_config['metadata'].get( - 'pretrained_model_name', None)) - if license_filename is not None: - mlflow_logger._mlflow_client.log_artifact( - mlflow_logger._run_id, - os.path.join(local_save_path, license_filename), - ) - - # Spawn a new process to register the model. - process = SpawnProcess( - target=_register_model_with_run_id_multiprocess, - kwargs={ - 'mlflow_logger': - mlflow_logger, - 'logging_level': - logging.getLogger('composer').level, - 'model_uri': - local_save_path, - 'name': - self.mlflow_registered_model_name, - 'await_creation_for': - 3600, - }) - process.start() - self.child_processes.append(process) + + # Spawn a new process to register the model. + process = SpawnProcess( + target=_register_model_with_run_id_multiprocess, + kwargs={ + 'mlflow_logger': + mlflow_logger, + 'logging_level': + logging.getLogger('composer').level, + 'model_uri': + local_save_path, + 'name': + self.mlflow_registered_model_name, + 'await_creation_for': + 3600, + }) + process.start() + self.child_processes.append(process) dist.barrier()