diff --git a/diffusion/planners/__init__.py b/diffusion/planners/__init__.py
new file mode 100644
index 00000000..efafbb0a
--- /dev/null
+++ b/diffusion/planners/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Composer checkpointing planners."""
+
+from diffusion.planners.lora_planner import LoraPlanner
+
+__all__ = ['LoraPlanner']
diff --git a/diffusion/planners/lora_planner.py b/diffusion/planners/lora_planner.py
new file mode 100644
index 00000000..3f856bd1
--- /dev/null
+++ b/diffusion/planners/lora_planner.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""LoRA Planner."""
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
+
+__all__ = ['LoraPlanner']
+
+
+class LoraPlanner(DefaultLoadPlanner):
+    """Takes a Composer checkpoint and converts it to a LoRA checkpoint."""
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        """Sets up the planner for converting a Composer checkpoint to a LoRA checkpoint.
+
+        Takes all targeted modules and checks whether they have already been LoRA-processed.
+        If not, renames their weights appropriately; if they have, leaves them unchanged for
+        autoresume compatibility.
+
+        Args:
+            state_dict (STATE_DICT_TYPE): Original torch state dict.
+            metadata (Metadata): Any metadata associated with the state dict.
+            is_coordinator (bool): Whether this rank is the coordinator of loading.
+        """
+        if 'state' not in state_dict:
+            super().set_up_planner(state_dict, metadata, is_coordinator)
+            return
+
+        self.original_state_dict = state_dict
+
+        state_dict = dict(state_dict.items())
+        state_dict['state'] = dict(state_dict['state'].items())
+        target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']
+
+        for key in list(state_dict['state']['model'].keys()):
+            for mod in target_modules:
+                if f'{mod}.weight' in key:
+                    new_key = key.replace(mod, mod + '.base_layer')
+                    state_dict['state']['model'][new_key] = state_dict['state']['model'].pop(key)
+                    break
+
+        if self.flatten_sharded_tensors:
+            state_dict = _flatten_sharded_tensors(state_dict)
+
+        if self.flatten_state_dict:
+            state_dict, self.mappings = flatten_state_dict(state_dict)
+
+        self.state_dict = state_dict
+        self.metadata = metadata
+        self.is_coordinator = is_coordinator
diff --git a/diffusion/train.py b/diffusion/train.py
index becff0f1..3cc7392c 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -21,6 +21,7 @@
 
 from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder
 from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
+from diffusion.planners import LoraPlanner
 
 
 def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer:
@@ -206,19 +207,28 @@ def train(config: DictConfig) -> None:
             print(f'Instantiating callbacks <{call_conf._target_}>')
             callbacks.append(hydra.utils.instantiate(call_conf))
 
+    if 'fsdp_config' in config.trainer:
+        fsdp_config = dict(config.trainer.fsdp_config)
+        del config.trainer.fsdp_config
+    else:
+        fsdp_config = None
+
+    if 'lora_rank' in config.model:
+        assert fsdp_config is not None, 'LoRA training requires an FSDP config.'
+        fsdp_config['load_planner'] = LoraPlanner
+
     scheduler = hydra.utils.instantiate(config.scheduler)
 
-    trainer: Trainer = hydra.utils.instantiate(
-        config.trainer,
-        train_dataloader=train_dataloader,
-        eval_dataloader=eval_set,
-        optimizers=optimizer,
-        model=model,
-        loggers=logger,
-        algorithms=algorithms,
-        schedulers=scheduler,
-        callbacks=callbacks,
-    )
+    trainer: Trainer = hydra.utils.instantiate(config.trainer,
+                                               train_dataloader=train_dataloader,
+                                               eval_dataloader=eval_set,
+                                               optimizers=optimizer,
+                                               model=model,
+                                               loggers=logger,
+                                               algorithms=algorithms,
+                                               schedulers=scheduler,
+                                               callbacks=callbacks,
+                                               fsdp_config=fsdp_config)
 
     def eval_and_then_train():
         if config.get('eval_first', True):
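As a point of reference for the planner above, the following standalone sketch (not part of the diff) illustrates the key renaming that LoraPlanner.set_up_planner performs; the state-dict keys are hypothetical examples, and a real checkpoint maps such keys to tensors rather than strings.

# Toy model state dict; values stand in for weight tensors.
target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']

model_state = {
    'unet.mid_block.attentions.0.to_q.weight': 'dense q weight',
    'unet.mid_block.attentions.0.to_out.0.weight': 'dense out weight',
    'unet.mid_block.attentions.0.to_k.base_layer.weight': 'already LoRA-processed',
}

# Iterate over a copy of the keys, since entries are renamed in place.
for key in list(model_state.keys()):
    for mod in target_modules:
        if f'{mod}.weight' in key:
            # e.g. '...to_q.weight' -> '...to_q.base_layer.weight'
            model_state[key.replace(mod, mod + '.base_layer')] = model_state.pop(key)
            break

print(sorted(model_state))
# ['unet.mid_block.attentions.0.to_k.base_layer.weight',
#  'unet.mid_block.attentions.0.to_out.0.base_layer.weight',
#  'unet.mid_block.attentions.0.to_q.base_layer.weight']

Renaming 'to_q.weight' to 'to_q.base_layer.weight' matches the parameter names that peft-style LoRA wrapping gives the frozen base weights, so a plain Composer checkpoint can be loaded into a LoRA-wrapped model; keys that already contain 'base_layer' do not match and are left untouched, which keeps autoresume from a LoRA checkpoint working.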