diff --git a/flux_train_network.py b/flux_train_network.py index 9bcd59282..704c4d32e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -127,45 +127,6 @@ def load_target_model(self, args, weight_dtype, accelerator): return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model - """ - def prepare_split_model(self, model, weight_dtype, accelerator): - from accelerate import init_empty_weights - - logger.info("prepare split model") - with init_empty_weights(): - flux_upper = flux_models.FluxUpper(model.params) - flux_lower = flux_models.FluxLower(model.params) - sd = model.state_dict() - - # lower (trainable) - logger.info("load state dict for lower") - flux_lower.load_state_dict(sd, strict=False, assign=True) - flux_lower.to(dtype=weight_dtype) - - # upper (frozen) - logger.info("load state dict for upper") - flux_upper.load_state_dict(sd, strict=False, assign=True) - - logger.info("prepare upper model") - target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype - flux_upper.to(accelerator.device, dtype=target_dtype) - flux_upper.eval() - - if args.fp8_base: - # this is required to run on fp8 - flux_upper = accelerator.prepare(flux_upper) - - flux_upper.to("cpu") - - self.flux_upper = flux_upper - del model # we don't need model anymore - clean_memory_on_device(accelerator.device) - - logger.info("split model prepared") - - return flux_lower - """ - def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)