From 25f9694fa17814adfa2b5e5c64c55ea90ffb4ebc Mon Sep 17 00:00:00 2001
From: rishab-partha <rishab.parthasarathy@gmail.com>
Date: Wed, 28 Aug 2024 20:27:24 +0000
Subject: [PATCH 1/4] LoRA Planner

---
 diffusion/planners/__init__.py     |  8 +++++
 diffusion/planners/lora_planner.py | 58 ++++++++++++++++++++++++++++++
 diffusion/train.py                 |  7 ++++
 3 files changed, 73 insertions(+)
 create mode 100644 diffusion/planners/__init__.py
 create mode 100644 diffusion/planners/lora_planner.py

diff --git a/diffusion/planners/__init__.py b/diffusion/planners/__init__.py
new file mode 100644
index 00000000..efafbb0a
--- /dev/null
+++ b/diffusion/planners/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Composer checkpointing planners."""
+
+from diffusion.planners.lora_planner import LoraPlanner
+
+__all__ = ['LoraPlanner']
diff --git a/diffusion/planners/lora_planner.py b/diffusion/planners/lora_planner.py
new file mode 100644
index 00000000..3f856bd1
--- /dev/null
+++ b/diffusion/planners/lora_planner.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""LoRA Planner."""
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
+
+__all__ = ['LoraPlanner']
+
+
+class LoraPlanner(DefaultLoadPlanner):
+    """Takes a Composer checkpoint and converts it to LoRA Checkpoint."""
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        """Sets up the planner for converting Composer to LoRA Checkpoint.
+
+        Takes all targeted modules and checks whether they have been LoRA processed. If not,
+        changes names of weights appropriately. If yes, doesn't change anything for autoresume
+        compatibility.
+
+        Args:
+            state_dict (STATE_DICT_TYPE): Original torch state dict.
+            metadata (METADATA): Any metadata associated with the state dict.
+            is_coordinator (bool): Whether the machine this is running on is the coordinator of loading.
+        """
+        if 'state' not in state_dict:
+            super().set_up_planner(state_dict, metadata, is_coordinator)
+            return
+
+        self.original_state_dict = state_dict
+
+        state_dict = dict(state_dict.items())
+        state_dict['state'] = dict(state_dict['state'].items())
+        target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']
+
+        for key in list(state_dict['state']['model'].keys()):
+            for mod in target_modules:
+                if f'{mod}.weight' in key:
+                    new_key = key.replace(mod, mod + '.base_layer')
+                    state_dict['state']['model'][new_key] = state_dict['state']['model'].pop(key)
+                    break
+
+        if self.flatten_sharded_tensors:
+            state_dict = _flatten_sharded_tensors(state_dict)
+
+        if self.flatten_state_dict:
+            state_dict, self.mappings = flatten_state_dict(state_dict)
+
+        self.state_dict = state_dict
+        self.metadata = metadata
+        self.is_coordinator = is_coordinator
diff --git a/diffusion/train.py b/diffusion/train.py
index becff0f1..dbac3f61 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -21,6 +21,7 @@
 
 from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder
 from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
+from diffusion.planners import LoraPlanner
 
 
 def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer:
@@ -206,6 +207,12 @@ def train(config: DictConfig) -> None:
                 print(f'Instantiating callbacks <{call_conf._target_}>')
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
+    if 'planners' in config:
+        for pl_name, pl_conf in config.planners.items():
+            if pl_name == 'lora_planner' and pl_conf:
+                assert 'fsdp_config' in config.trainer
+                config.trainer.fsdp_config.load_planner = LoraPlanner
+
     scheduler = hydra.utils.instantiate(config.scheduler)
 
     trainer: Trainer = hydra.utils.instantiate(

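For illustration, a minimal standalone sketch of the key renaming that
LoraPlanner.set_up_planner performs. The toy state dict below is hypothetical,
but the transformation mirrors the patch:

    # Toy example of the renaming applied to targeted attention weights.
    target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']

    model_state = {
        'unet.down_blocks.0.attentions.0.to_q.weight': 'q-weight',           # not yet LoRA-wrapped
        'unet.down_blocks.0.attentions.0.to_out.0.weight': 'out-weight',     # not yet LoRA-wrapped
        'unet.up_blocks.0.attentions.0.to_k.base_layer.weight': 'k-weight',  # already LoRA-wrapped
    }

    # Iterate over a snapshot of the keys, since the dict is mutated inside the loop.
    for key in list(model_state.keys()):
        for mod in target_modules:
            if f'{mod}.weight' in key:
                # e.g. '...to_q.weight' -> '...to_q.base_layer.weight' (peft-style naming)
                model_state[key.replace(mod, mod + '.base_layer')] = model_state.pop(key)
                break

    # Keys that already contain 'base_layer' never match f'{mod}.weight', so they pass
    # through untouched, which is what keeps autoresume loads working unchanged.
    print(sorted(model_state))
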
From 9068d44dfe6a2e4b28af960a965590223f2694bc Mon Sep 17 00:00:00 2001
From: rishab-partha <rishab.parthasarathy@gmail.com>
Date: Wed, 28 Aug 2024 21:40:48 +0000
Subject: [PATCH 2/4] automatically use LoraPlanner for composer model

---
 diffusion/train.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index dbac3f61..e0d02604 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -207,11 +207,9 @@ def train(config: DictConfig) -> None:
                 print(f'Instantiating callbacks <{call_conf._target_}>')
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
-    if 'planners' in config:
-        for pl_name, pl_conf in config.planners.items():
-            if pl_name == 'lora_planner' and pl_conf:
-                assert 'fsdp_config' in config.trainer
-                config.trainer.fsdp_config.load_planner = LoraPlanner
+    if 'lora_rank' in config.model:
+        assert 'fsdp_config' in config.trainer
+        config.trainer.fsdp_config.load_planner = LoraPlanner
 
     scheduler = hydra.utils.instantiate(config.scheduler)
 

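A minimal sketch of the new gate, assuming a run config whose model section defines
lora_rank (the config values below are hypothetical):

    from omegaconf import OmegaConf

    # Hypothetical run config; only the keys relevant to the gate are shown.
    config = OmegaConf.create({
        'model': {'lora_rank': 16},  # presence of this key is what triggers the planner
        'trainer': {'fsdp_config': {'sharding_strategy': 'FULL_SHARD'}},
    })

    # The gate from the patch: any model config that defines lora_rank automatically
    # gets LoraPlanner attached to the FSDP config used for checkpoint loading.
    if 'lora_rank' in config.model:
        assert 'fsdp_config' in config.trainer
        print('lora_rank found; LoraPlanner will be used as the FSDP load planner.')
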
From c14310d32008f51391df430bf358aadd30ec7363 Mon Sep 17 00:00:00 2001
From: rishab-partha <rishab.parthasarathy@gmail.com>
Date: Sat, 7 Sep 2024 00:42:28 +0000
Subject: [PATCH 3/4] fix hydra init

---
 diffusion/train.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index e0d02604..59d7a66f 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -207,23 +207,28 @@ def train(config: DictConfig) -> None:
                 print(f'Instantiating callbacks <{call_conf._target_}>')
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
+    if 'fsdp_config' in config.trainer:
+        fsdp_config = config.trainer.pop('fsdp_config')
+        fsdp_config = dict(fsdp_config)
+    else:
+        fsdp_config = None
+
     if 'lora_rank' in config.model:
-        assert 'fsdp_config' in config.trainer
-        config.trainer.fsdp_config.load_planner = LoraPlanner
+        assert fsdp_config is not None
+        fsdp_config['load_planner'] = LoraPlanner
 
     scheduler = hydra.utils.instantiate(config.scheduler)
 
-    trainer: Trainer = hydra.utils.instantiate(
-        config.trainer,
-        train_dataloader=train_dataloader,
-        eval_dataloader=eval_set,
-        optimizers=optimizer,
-        model=model,
-        loggers=logger,
-        algorithms=algorithms,
-        schedulers=scheduler,
-        callbacks=callbacks,
-    )
+    trainer: Trainer = hydra.utils.instantiate(config.trainer,
+                                               train_dataloader=train_dataloader,
+                                               eval_dataloader=eval_set,
+                                               optimizers=optimizer,
+                                               model=model,
+                                               loggers=logger,
+                                               algorithms=algorithms,
+                                               schedulers=scheduler,
+                                               callbacks=callbacks,
+                                               fsdp_config=fsdp_config)
 
     def eval_and_then_train():
         if config.get('eval_first', True):

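A plausible reason for this restructuring, sketched outside the patch with hypothetical
config values: an OmegaConf config only stores config primitives, so the LoraPlanner class
has to be attached after the FSDP settings are copied into a plain dict, which
hydra.utils.instantiate then forwards to the Trainer as an extra keyword argument:

    from omegaconf import OmegaConf

    from diffusion.planners import LoraPlanner

    conf = OmegaConf.create({'fsdp_config': {'sharding_strategy': 'FULL_SHARD'}})

    try:
        # OmegaConf configs only hold config primitives, so attaching a class is rejected.
        conf.fsdp_config.load_planner = LoraPlanner
    except Exception as err:
        print(f'cannot store a class inside the config: {err}')

    # A plain dict has no such restriction, so the planner class is attached here and the
    # dict is passed to the Trainer via hydra.utils.instantiate(config.trainer, ..., fsdp_config=fsdp_config).
    fsdp_config = dict(conf.fsdp_config)
    fsdp_config['load_planner'] = LoraPlanner
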
From 6fe3ffba77bb8f41c35c10f73e13075d15fa7648 Mon Sep 17 00:00:00 2001
From: Rishab Parthasarathy <56666587+rishab-partha@users.noreply.github.com>
Date: Tue, 24 Sep 2024 12:07:57 -0700
Subject: [PATCH 4/4] try to fix?

---
 diffusion/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index 59d7a66f..3cc7392c 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -208,8 +208,8 @@ def train(config: DictConfig) -> None:
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
     if 'fsdp_config' in config.trainer:
-        fsdp_config = config.trainer.pop('fsdp_config')
-        fsdp_config = dict(fsdp_config)
+        fsdp_config = dict(config.trainer.fsdp_config)
+        del config.trainer.fsdp_config
     else:
         fsdp_config = None