Commit 5c5b544

refactor: remove unused prepare_split_model method from FluxNetworkTrainer

kohya-ss committed Nov 14, 2024
1 parent 2bb0f54 commit 5c5b544
Showing 1 changed file with 0 additions and 39 deletions.

flux_train_network.py (0 additions, 39 deletions)
@@ -127,45 +127,6 @@ def load_target_model(self, args, weight_dtype, accelerator):

         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
 
-    """
-    def prepare_split_model(self, model, weight_dtype, accelerator):
-        from accelerate import init_empty_weights
-
-        logger.info("prepare split model")
-        with init_empty_weights():
-            flux_upper = flux_models.FluxUpper(model.params)
-            flux_lower = flux_models.FluxLower(model.params)
-        sd = model.state_dict()
-
-        # lower (trainable)
-        logger.info("load state dict for lower")
-        flux_lower.load_state_dict(sd, strict=False, assign=True)
-        flux_lower.to(dtype=weight_dtype)
-
-        # upper (frozen)
-        logger.info("load state dict for upper")
-        flux_upper.load_state_dict(sd, strict=False, assign=True)
-
-        logger.info("prepare upper model")
-        target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
-        flux_upper.to(accelerator.device, dtype=target_dtype)
-        flux_upper.eval()
-
-        if args.fp8_base:
-            # this is required to run on fp8
-            flux_upper = accelerator.prepare(flux_upper)
-
-        flux_upper.to("cpu")
-
-        self.flux_upper = flux_upper
-        del model  # we don't need model anymore
-
-        clean_memory_on_device(accelerator.device)
-
-        logger.info("split model prepared")
-
-        return flux_lower
-    """
 
     def get_tokenize_strategy(self, args):
         _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
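For context: the deleted prepare_split_model implemented a split-model memory optimization, dividing FLUX into a frozen upper half (kept on CPU and optionally cast to fp8 when --fp8_base is set, swapped onto the GPU only while it runs) and a trainable lower half returned to the caller. The sketch below illustrates that general pattern only, under stated assumptions: UpperHalf, LowerHalf, and training_step are hypothetical stand-ins, not the actual flux_models.FluxUpper/FluxLower API or the trainer's real control flow.

```python
# Minimal sketch of the split-model offload pattern the removed method used.
# UpperHalf/LowerHalf are hypothetical stand-ins for flux_models.FluxUpper/
# FluxLower; the real modules have different constructors and forward signatures.
import torch
import torch.nn as nn

class UpperHalf(nn.Module):  # frozen half of the network
    def __init__(self, dim=64):
        super().__init__()
        self.blocks = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.blocks(x)

class LowerHalf(nn.Module):  # trainable half of the network
    def __init__(self, dim=64):
        super().__init__()
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        return self.out(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

upper = UpperHalf().eval().requires_grad_(False)  # frozen, parked on CPU
lower = LowerHalf().to(device)                    # trainable, stays on the GPU

def training_step(batch):
    upper.to(device)              # swap the frozen half onto the GPU...
    with torch.no_grad():
        hidden = upper(batch)     # ...run it without building a graph...
    upper.to("cpu")               # ...then move it back to free VRAM
    return lower(hidden)          # gradients flow only through the lower half

x = torch.randn(2, 64, device=device)
loss = training_step(x).square().mean()
loss.backward()  # updates reach only lower's parameters
```

The trade-off is the per-step host-to-device transfer of the frozen half in exchange for a much smaller resident VRAM footprint; the removed method additionally cast the frozen half to torch.float8_e4m3fn to shrink that transfer.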
