diff --git a/flux_train.py b/flux_train.py index d2a9b3f32..56b057085 100644 --- a/flux_train.py +++ b/flux_train.py @@ -310,7 +310,7 @@ def train(args): logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") else: - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize, model=flux) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset