diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 14ff7c240..913b1d435 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -181,7 +181,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device)