Fix "push_to_hub only create repo in consistency model lora SDXL trai…
Browse files Browse the repository at this point in the history
…ning script" (huggingface#6102)

* fix

* style fix

---------

Co-authored-by: Sayak Paul <[email protected]>
2 people authored and donhardman committed Dec 29, 2023
1 parent 9a1a1de commit 6893b27
Showing 4 changed files with 40 additions and 8 deletions.
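
The change is identical across all four training scripts: previously, --push_to_hub created the Hub repo at startup and then never uploaded anything, so training ended with an empty repo. The fix keeps the repo id returned by create_repo and passes it to upload_folder once training finishes. Below is a minimal sketch of the resulting pattern; the argparse Namespace values are hypothetical stand-ins for each script's own arguments, and taking .repo_id from the returned RepoUrl reflects the closing of the create_repo call that the truncated hunks below don't show.

```python
from argparse import Namespace
from pathlib import Path

from huggingface_hub import create_repo, upload_folder

# Hypothetical stand-in for each script's parsed CLI arguments.
args = Namespace(
    push_to_hub=True,
    hub_model_id=None,   # default: derive the repo name from output_dir
    hub_token=None,      # default: use the locally cached Hub token
    output_dir="lcm-distill-lora-sd",
)

if args.push_to_hub:
    # create_repo returns a RepoUrl; its .repo_id is the canonical
    # "user/name" id needed for the upload at the end of training.
    repo_id = create_repo(
        repo_id=args.hub_model_id or Path(args.output_dir).name,
        exist_ok=True,
        token=args.hub_token,
    ).repo_id

# ... training loop runs and writes weights into args.output_dir ...

if args.push_to_hub:
    # The step that was missing before this commit: actually upload the
    # results, skipping intermediate checkpoint directories.
    upload_folder(
        repo_id=repo_id,
        folder_path=args.output_dir,
        commit_message="End of training",
        ignore_patterns=["step_*", "epoch_*"],
    )
```

exist_ok=True makes the startup call idempotent across restarted runs, and ignore_patterns keeps intermediate checkpoints out of the final Hub commit.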
12 changes: 10 additions & 2 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -38,7 +38,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
 from torch.utils.data import default_collate
@@ -847,7 +847,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1366,6 +1366,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
         lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
         StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()

12 changes: 10 additions & 2 deletions examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -39,7 +39,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
 from torch.utils.data import default_collate
@@ -842,7 +842,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1424,6 +1424,14 @@ def compute_embeddings(
         lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
         StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()

12 changes: 10 additions & 2 deletions examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -38,7 +38,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torch.utils.data import default_collate
 from torchvision import transforms
@@ -835,7 +835,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1354,6 +1354,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
         target_unet = accelerator.unwrap_model(target_unet)
         target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()

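For the two non-LoRA scripts, the uploaded repo ends up containing the target UNet under unet_target/ (mirroring the save_pretrained call above), so it can be pulled back with diffusers' standard subfolder loading. A sketch, with a hypothetical repo id:

```python
from diffusers import UNet2DConditionModel

# Hypothetical repo id from a --push_to_hub run of this script; the
# target UNet was saved (and uploaded) under the unet_target/ subfolder.
unet = UNet2DConditionModel.from_pretrained(
    "your-username/lcm-distill-sd", subfolder="unet_target"
)
```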
12 changes: 10 additions & 2 deletions examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -39,7 +39,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torch.utils.data import default_collate
 from torchvision import transforms
@@ -875,7 +875,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1457,6 +1457,14 @@ def compute_embeddings(
         target_unet = accelerator.unwrap_model(target_unet)
         target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
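
A quick way to confirm the fix did its job is to list the repo contents after a --push_to_hub run; the final weights should be present while anything matched by ignore_patterns should not. A small sketch with a hypothetical repo id:

```python
from huggingface_hub import list_repo_files

# Hypothetical repo id; substitute the one created by the training run.
files = list_repo_files("your-username/lcm-distill-lora-sdxl")

print(files)
# Nothing matched by ignore_patterns=["step_*", "epoch_*"] should appear.
assert not any(f.startswith(("step_", "epoch_")) for f in files)
```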
