From 25f9694fa17814adfa2b5e5c64c55ea90ffb4ebc Mon Sep 17 00:00:00 2001
From: rishab-partha <rishab.parthasarathy@gmail.com>
Date: Wed, 28 Aug 2024 20:27:24 +0000
Subject: [PATCH 1/4] LoRA Planner

---
 diffusion/planners/__init__.py     |  8 +++++
 diffusion/planners/lora_planner.py | 58 ++++++++++++++++++++++++++++++
 diffusion/train.py                 |  7 ++++
 3 files changed, 73 insertions(+)
 create mode 100644 diffusion/planners/__init__.py
 create mode 100644 diffusion/planners/lora_planner.py

diff --git a/diffusion/planners/__init__.py b/diffusion/planners/__init__.py
new file mode 100644
index 00000000..efafbb0a
--- /dev/null
+++ b/diffusion/planners/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Composer checkpointing planners."""
+
+from diffusion.planners.lora_planner import LoraPlanner
+
+__all__ = ['LoraPlanner']
diff --git a/diffusion/planners/lora_planner.py b/diffusion/planners/lora_planner.py
new file mode 100644
index 00000000..3f856bd1
--- /dev/null
+++ b/diffusion/planners/lora_planner.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""LoRA Planner."""
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
+
+__all__ = ['LoraPlanner']
+
+
+class LoraPlanner(DefaultLoadPlanner):
+    """Takes a Composer checkpoint and converts it to a LoRA checkpoint."""
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        """Sets up the planner for converting a Composer checkpoint to a LoRA checkpoint.
+
+        Checks each targeted module to see whether it has already been LoRA-processed. If it
+        has not, renames its weights accordingly; if it has, leaves the weights unchanged so
+        that autoresume keeps working.
+
+        Args:
+            state_dict (STATE_DICT_TYPE): Original torch state dict.
+            metadata (Metadata): Any metadata associated with the state dict.
+            is_coordinator (bool): Whether the machine this is running on is the coordinator of loading.
+        """
+        if 'state' not in state_dict:
+            super().set_up_planner(state_dict, metadata, is_coordinator)
+            return
+
+        self.original_state_dict = state_dict
+
+        state_dict = dict(state_dict.items())
+        state_dict['state'] = dict(state_dict['state'].items())
+        target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']
+
+        for key in state_dict['state']['model'].keys():
+            for mod in target_modules:
+                if f'{mod}.weight' in key:
+                    new_key = key.replace(mod, mod + '.base_layer')
+                    state_dict['state']['model'][new_key] = state_dict['state']['model'].pop(key)
+                    break
+
+        if self.flatten_sharded_tensors:
+            state_dict = _flatten_sharded_tensors(state_dict)
+
+        if self.flatten_state_dict:
+            state_dict, self.mappings = flatten_state_dict(state_dict)
+
+        self.state_dict = state_dict
+        self.metadata = metadata
+        self.is_coordinator = is_coordinator
diff --git a/diffusion/train.py b/diffusion/train.py
index becff0f1..dbac3f61 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -21,6 +21,7 @@
 from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder
 from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
+from diffusion.planners import LoraPlanner
 
 
 def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer:
@@ -206,6 +207,12 @@ def train(config: DictConfig) -> None:
                 print(f'Instantiating callbacks <{call_conf._target_}>')
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
+    if 'planners' in config:
+        for pl_name, pl_conf in config.planners.items():
+            if pl_name == 'lora_planner' and pl_conf:
+                assert 'fsdp_config' in config.trainer
+                config.trainer.fsdp_config.load_planner = LoraPlanner
+
     scheduler = hydra.utils.instantiate(config.scheduler)
 
     trainer: Trainer = hydra.utils.instantiate(

From 9068d44dfe6a2e4b28af960a965590223f2694bc Mon Sep 17 00:00:00 2001
From: rishab-partha <rishab.parthasarathy@gmail.com>
Date: Wed, 28 Aug 2024 21:40:48 +0000
Subject: [PATCH 2/4] automatically use LoraPlanner for composer model

---
 diffusion/train.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index dbac3f61..e0d02604 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -207,11 +207,9 @@ def train(config: DictConfig) -> None:
                 print(f'Instantiating callbacks <{call_conf._target_}>')
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
-    if 'planners' in config:
-        for pl_name, pl_conf in config.planners.items():
-            if pl_name == 'lora_planner' and pl_conf:
-                assert 'fsdp_config' in config.trainer
-                config.trainer.fsdp_config.load_planner = LoraPlanner
+    if 'lora_rank' in config.model:
+        assert 'fsdp_config' in config.trainer
+        config.trainer.fsdp_config.load_planner = LoraPlanner
 
     scheduler = hydra.utils.instantiate(config.scheduler)
 

From c14310d32008f51391df430bf358aadd30ec7363 Mon Sep 17 00:00:00 2001
From: rishab-partha <rishab.parthasarathy@gmail.com>
Date: Sat, 7 Sep 2024 00:42:28 +0000
Subject: [PATCH 3/4] fix hydra init

---
 diffusion/train.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index e0d02604..59d7a66f 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -207,23 +207,28 @@ def train(config: DictConfig) -> None:
                 print(f'Instantiating callbacks <{call_conf._target_}>')
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
+    if 'fsdp_config' in config.trainer:
+        fsdp_config = config.trainer.pop('fsdp_config')
+        fsdp_config = dict(fsdp_config)
+    else:
+        fsdp_config = None
+
     if 'lora_rank' in config.model:
-        assert 'fsdp_config' in config.trainer
-        config.trainer.fsdp_config.load_planner = LoraPlanner
+        assert fsdp_config is not None
+        fsdp_config['load_planner'] = LoraPlanner
 
     scheduler = hydra.utils.instantiate(config.scheduler)
 
-    trainer: Trainer = hydra.utils.instantiate(
-        config.trainer,
-        train_dataloader=train_dataloader,
-        eval_dataloader=eval_set,
-        optimizers=optimizer,
-        model=model,
-        loggers=logger,
-        algorithms=algorithms,
-        schedulers=scheduler,
-        callbacks=callbacks,
-    )
+    trainer: Trainer = hydra.utils.instantiate(config.trainer,
+                                               train_dataloader=train_dataloader,
+                                               eval_dataloader=eval_set,
+                                               optimizers=optimizer,
+                                               model=model,
+                                               loggers=logger,
+                                               algorithms=algorithms,
+                                               schedulers=scheduler,
+                                               callbacks=callbacks,
+                                               fsdp_config=fsdp_config)
 
     def eval_and_then_train():
         if config.get('eval_first', True):

From 6fe3ffba77bb8f41c35c10f73e13075d15fa7648 Mon Sep 17 00:00:00 2001
From: Rishab Parthasarathy <56666587+rishab-partha@users.noreply.github.com>
Date: Tue, 24 Sep 2024 12:07:57 -0700
Subject: [PATCH 4/4] try to fix?

---
 diffusion/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index 59d7a66f..3cc7392c 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -208,8 +208,8 @@ def train(config: DictConfig) -> None:
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
     if 'fsdp_config' in config.trainer:
-        fsdp_config = config.trainer.pop('fsdp_config')
-        fsdp_config = dict(fsdp_config)
+        fsdp_config = dict(config.trainer.fsdp_config)
+        config.trainer.__delattr__("fsdp_config")
     else:
         fsdp_config = None
 
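
The key remapping that LoraPlanner.set_up_planner performs can be seen in isolation with a small standalone sketch. The checkpoint key names below are hypothetical (real names depend on the UNet layout); the loop is the same substring-and-replace rule the planner applies to state_dict['state']['model']:

# Standalone sketch of the rename rule from LoraPlanner.set_up_planner, applied
# to a toy model state dict. The key names here are made-up examples.
target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']

model_state = {
    # Key from a non-LoRA checkpoint: gets remapped onto the wrapped base layer.
    'unet.mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight': 'W_q',
    # Keys from a checkpoint that was already LoRA-processed (the autoresume case):
    # neither contains 'to_k.weight', so both are left untouched.
    'unet.mid_block.attentions.0.transformer_blocks.0.attn2.to_k.base_layer.weight': 'W_k',
    'unet.mid_block.attentions.0.transformer_blocks.0.attn2.to_k.lora_A.default.weight': 'A_k',
}

for key in list(model_state.keys()):
    for mod in target_modules:
        if f'{mod}.weight' in key:
            new_key = key.replace(mod, mod + '.base_layer')
            model_state[new_key] = model_state.pop(key)
            break

for key in sorted(model_state):
    print(key)
# ...attn1.to_q.base_layer.weight        (remapped from ...attn1.to_q.weight)
# ...attn2.to_k.base_layer.weight        (unchanged)
# ...attn2.to_k.lora_A.default.weight    (unchanged)

This is also why the planner can stay enabled for autoresume: keys that already carry the base_layer or lora_* naming never match the f'{mod}.weight' test, so a second pass is a no-op.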
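
The last two patches move fsdp_config out of the Hydra/OmegaConf config and into a plain dict before attaching the planner, presumably because an OmegaConf node generally cannot hold an arbitrary Python class such as LoraPlanner. A rough sketch of the resulting wiring, with a made-up trainer/model config (the real values come from the training YAML; only the 'lora_rank' and 'fsdp_config' keys mirror train.py):

# Rough sketch of the planner wiring after PATCH 3/4 and 4/4.
from omegaconf import OmegaConf

from diffusion.planners import LoraPlanner

config = OmegaConf.create({
    'model': {'lora_rank': 16},
    'trainer': {'fsdp_config': {'sharding_strategy': 'SHARD_GRAD_OP'}},
})

# Pull fsdp_config out of the OmegaConf tree into a plain dict so it can hold a class.
if 'fsdp_config' in config.trainer:
    fsdp_config = dict(config.trainer.fsdp_config)
    config.trainer.__delattr__('fsdp_config')
else:
    fsdp_config = None

# A LoRA run (signalled by lora_rank in the model config) gets the LoRA-aware load planner.
if 'lora_rank' in config.model:
    assert fsdp_config is not None
    fsdp_config['load_planner'] = LoraPlanner

# train.py then passes fsdp_config=fsdp_config to the Composer Trainer.
print(fsdp_config)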