Skip to content

Commit

Permalink
Merge branch 'main' into feat/ci-benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Dec 5, 2023
2 parents b394168 + f427345 commit 70f3556
Show file tree
Hide file tree
Showing 13 changed files with 376 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,26 @@ def save_model_card(
"""

trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
diffusers_example_pivotal = ""
if train_text_encoder_ti:
trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
"in you prompt with the new inserted tokens:\n"
)
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
"""
if token_abstraction_dict:
for key, value in token_abstraction_dict.items():
tokens = "".join(value)
trigger_str += f"""
to trigger concept `{key}->` use `{tokens}` in your prompt \n
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""

yaml = f"""
Expand Down Expand Up @@ -172,7 +182,21 @@ def save_model_card(
{trigger_str}
## Download model
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
{diffusers_imports_pivotal}
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
{diffusers_example_pivotal}
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
Weights for this model are available in Safetensors format.
Expand Down Expand Up @@ -791,6 +815,12 @@ def __init__(
instance_data_root,
instance_prompt,
class_prompt,
dataset_name,
dataset_config_name,
cache_dir,
image_column,
caption_column,
train_text_encoder_ti,
class_data_root=None,
class_num=None,
token_abstraction_dict=None, # token mapping for textual inversion
Expand All @@ -805,10 +835,10 @@ def __init__(
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.token_abstraction_dict = token_abstraction_dict

self.train_text_encoder_ti = train_text_encoder_ti
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
if dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
Expand All @@ -821,38 +851,37 @@ def __init__(
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
dataset_name,
dataset_config_name,
cache_dir=cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
if args.image_column is None:
if image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]

if args.caption_column is None:
if caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
"column as --caption_column"
)
self.custom_instance_prompts = None
else:
if args.caption_column not in column_names:
if caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][args.caption_column]
custom_instance_prompts = dataset["train"][caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
Expand Down Expand Up @@ -907,7 +936,7 @@ def __getitem__(self, index):
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
if args.train_text_encoder_ti:
if self.train_text_encoder_ti:
# replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in self.token_abstraction_dict.items():
caption = caption.replace(token_abs, "".join(token_replacement))
Expand Down Expand Up @@ -1093,10 +1122,10 @@ def main(args):
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

model_id = args.hub_model_id or Path(args.output_dir).name
repo_id = None
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id

# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -1464,6 +1493,12 @@ def load_model_hook(models, input_dir):
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
dataset_name=args.dataset_name,
dataset_config_name=args.dataset_config_name,
cache_dir=args.cache_dir,
image_column=args.image_column,
train_text_encoder_ti=args.train_text_encoder_ti,
caption_column=args.caption_column,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
class_num=args.num_class_images,
Expand Down Expand Up @@ -2004,23 +2039,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
}
)

if args.push_to_hub:
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/embeddings.safetensors",
)
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
token_abstraction_dict=train_dataset.token_abstraction_dict,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/embeddings.safetensors",
)
save_model_card(
model_id if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
token_abstraction_dict=train_dataset.token_abstraction_dict,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/loaders/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
)

if torch_dtype is not None:
pipe.to(torch_dtype=torch_dtype)
pipe.to(dtype=torch_dtype)

return pipe

Expand Down
Loading

0 comments on commit 70f3556

Please sign in to comment.