From 6e2f11d3deb4b943f80068c1d863cecf3a3e1902 Mon Sep 17 00:00:00 2001 From: freddy_jiao Date: Fri, 20 Dec 2024 20:24:07 +0800 Subject: [PATCH] upload code --- Img-Diff | 1 - Img-Diff-codes/README.md | 84 +++ .../object_removal/generate_inpaint.py | 532 ++++++++++++++++ .../object_removal/run_generate_inpaint.sh | 11 + .../object_replacement/cos_count.py | 63 ++ .../object_replacement/cos_filter.py | 83 +++ .../object_replacement/cos_filter.sh | 4 + .../object_replacement/generate_bbox.py | 366 +++++++++++ .../object_replacement/generate_bbox.sh | 7 + .../generate_final_data_new_edit.py | 372 +++++++++++ .../generate_final_data_new_edit.sh | 11 + Img-Diff-codes/pairs_generator/gen.py | 99 +++ Img-Diff-codes/pairs_generator/gen.sh | 5 + .../pairs_generator/gen_new_data_ddp.py | 81 +++ .../pairs_generator/gen_sdxl_new_data_ddp.sh | 4 + Img-Diff-codes/pairs_generator/processors.py | 596 ++++++++++++++++++ .../prompt_to_prompt_pipeline.py | 473 ++++++++++++++ README.md | 2 +- 18 files changed, 2792 insertions(+), 2 deletions(-) delete mode 160000 Img-Diff create mode 100644 Img-Diff-codes/README.md create mode 100644 Img-Diff-codes/object_removal/generate_inpaint.py create mode 100644 Img-Diff-codes/object_removal/run_generate_inpaint.sh create mode 100644 Img-Diff-codes/object_replacement/cos_count.py create mode 100644 Img-Diff-codes/object_replacement/cos_filter.py create mode 100644 Img-Diff-codes/object_replacement/cos_filter.sh create mode 100644 Img-Diff-codes/object_replacement/generate_bbox.py create mode 100644 Img-Diff-codes/object_replacement/generate_bbox.sh create mode 100644 Img-Diff-codes/object_replacement/generate_final_data_new_edit.py create mode 100644 Img-Diff-codes/object_replacement/generate_final_data_new_edit.sh create mode 100644 Img-Diff-codes/pairs_generator/gen.py create mode 100644 Img-Diff-codes/pairs_generator/gen.sh create mode 100644 Img-Diff-codes/pairs_generator/gen_new_data_ddp.py create mode 100644 Img-Diff-codes/pairs_generator/gen_sdxl_new_data_ddp.sh create mode 100644 Img-Diff-codes/pairs_generator/processors.py create mode 100644 Img-Diff-codes/pairs_generator/prompt_to_prompt_pipeline.py diff --git a/Img-Diff b/Img-Diff deleted file mode 160000 index 34538b4d4..000000000 --- a/Img-Diff +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 34538b4d4cba56b8ec6d14009317921ad39a37ec diff --git a/Img-Diff-codes/README.md b/Img-Diff-codes/README.md new file mode 100644 index 000000000..fc98c069e --- /dev/null +++ b/Img-Diff-codes/README.md @@ -0,0 +1,84 @@ +# Img-Diff: Contrastive Data Syhthesis for Multimodal Large Language Models + + +## Environment + +``` +transformers==4.36.2 +``` + +For the other requirements, please refer to [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main) and [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/). + + + + + +## Image Pairs Generator + +### step1 : generate caption pairs + +```shell +# Img_Diff/pairs_generator/ +$ bash gen.sh +``` + + + +### step2 : generate image pairs + +```shell +# Img_Diff/pairs_generator/ +$ bash gen_sdxl_new_data_ddp.sh +``` + + + + + +## Object Replacement Data Generator + +### step1 : calculate image similarity + +```shell +# Img_Diff/object_replacement/ +$ bash cos_filter.sh +``` + + + +### step2 : image similarity filter + +```shell +# Img_Diff/object_replacement/ +$ python cos_count.py +``` + + + +### step3 : generate difference area + +```shell +# Img_Diff/object_replacement/ +$ bash generate_bbox.sh +``` + + + +### step4 : generate difference captions + +```shell +# Img_Diff/object_replacement/ +$ bash generate_final_data_new_edit.sh +``` + + + + + +## Object Removal Data Generator + +```shell +# Img_Diff/object_removal/ +$ bash run_generate_inpaint.sh +``` + diff --git a/Img-Diff-codes/object_removal/generate_inpaint.py b/Img-Diff-codes/object_removal/generate_inpaint.py new file mode 100644 index 000000000..e581a704f --- /dev/null +++ b/Img-Diff-codes/object_removal/generate_inpaint.py @@ -0,0 +1,532 @@ +import sys +sys.path.append("../object_replacement/LLaVA/") + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path + +from ultralytics import FastSAM, YOLO +from transformers import CLIPImageProcessor, CLIPModel +import json +import numpy as np +import tqdm +import cv2 +from PIL import Image, ImageDraw, ImageColor +import torch +from torch.utils.data import Dataset, DataLoader +import difflib +from transformers import BlipProcessor, BlipForImageTextRetrieval +import re +from diffusers import AutoPipelineForInpainting +from diffusers.utils import load_image, make_image_grid +import random +import argparse +import os +import nltk +from nltk.corpus import wordnet as wn +from nltk import pos_tag +from nltk.tokenize import word_tokenize + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--out_path', type=str, default="./inpaint") + parser.add_argument('--vit_path', type=str, default="clip-vit-base-patch32") + parser.add_argument('--blip_path', type=str, default="blip-itm-large-coco") + parser.add_argument('--fastsam_path', type=str, default="FastSAM-x.pt") + parser.add_argument('--json_path', type=str, default="./filtered_file_new_edit_09_098_3.json") + parser.add_argument('--sd_model_path', type=str, default="stable-diffusion-xl-base-1.0") + parser.add_argument('--split_name', type=str, default="0") + parser.add_argument("--output_file", type=str, default="inpaint_0.json") + parser.add_argument('--mllm_path', type=str, default="./llava-v1.6-vicuna-7b") + args=parser.parse_args() + + return args + + +def iou_filter(samples, iou_thresh): + x1 = samples[:, 0] + y1 = samples[:, 1] + x2 = samples[:, 2] + y2 = samples[:, 3] + scores = samples[:, 4] + + areas = (y2 - y1 + 1) * (x2 - x1 + 1) + keep_boxes = [] + index = scores.argsort() # Ascending + + while len(index) > 0: + i = index[0] + keep_boxes.append(i) + + x1_overlap = np.maximum(x1[i], x1[index[1:]]) + y1_overlap = np.maximum(y1[i], y1[index[1:]]) + x2_overlap = np.minimum(x2[i], x2[index[1:]]) + y2_overlap = np.minimum(y2[i], y2[index[1:]]) + + w = np.maximum(0, x2_overlap - x1_overlap + 1) + h = np.maximum(0, y2_overlap - y1_overlap + 1) + overlap_area = w * h + + ious = overlap_area / (areas[i] + areas[index[1:]] - overlap_area) + + idx = np.where(ious <= iou_thresh)[0] + index = index[idx + 1] # update + + return samples[keep_boxes] + + +class InferenceDataset_for_FastSAM(Dataset): + + def __init__(self, json_path): + with open(json_path, "r") as f: + self.image_path = json.load(f) + + def __len__(self) -> int: + return len(self.image_path) + + @torch.no_grad() + def __getitem__(self, idx: int): + # image_array1 = cv2.cvtColor(cv2.imread(self.image_path[idx] + "_0.jpg"), cv2.COLOR_BGR2RGB) + # image_array2 = cv2.cvtColor(cv2.imread(self.image_path[idx] + "_1.jpg"), cv2.COLOR_BGR2RGB) + # image_array2 = cv2.resize(image_array2,(image_array1.shape[1],image_array1.shape[0])) + + return self.image_path[idx] + + + +class InferenceDataset_for_clip(Dataset): + + def __init__(self, image_list): + self.image_list = image_list + + def __len__(self) -> int: + return len(self.image_list) + + @torch.no_grad() + def __getitem__(self, idx: int): + + return self.image_list[idx] + + +class InferenceDataset_for_blip(Dataset): + + def __init__(self, pixel_values): + self.pixel_values = pixel_values + + def __len__(self) -> int: + return len(self.pixel_values) + + @torch.no_grad() + def __getitem__(self, idx: int): + + return self.pixel_values[idx] + + +def is_noun(word): + + pos_tagged = pos_tag([word]) + pos = pos_tagged[0][1] + + return pos in ['NN', 'NNS', 'NNP', 'NNPS'] + + +def compare_text_index(text1, text2): + # matcher = difflib.SequenceMatcher(a=text1, b=text2) + # diff_report = matcher.get_opcodes() + + # for tag, i1, i2, j1, j2 in diff_report: + # if tag == 'replace': + # return text1[i1:i2], text2[j1:j2] + + + d = difflib.Differ() + diff = d.compare(re.sub(r'[^\w\s]', '', text1.lower().replace(" ", "\n")).splitlines(), re.sub(r'[^\w\s]', '', text2.lower().replace(" ", "\n")).splitlines()) + + text1 = "" + text2 = "" + + for line in diff: + if line.startswith('+'): + if not is_noun(line.replace("+ ", "")): + continue + text1 = text1 + " " + line.replace("+ ", "") + elif line.startswith('-'): + if not is_noun(line.replace("- ", "")): + continue + text2 = text2 + " " + line.replace("- ", "") + + + return text1.strip(), text2.strip() + + + + +if __name__ == "__main__": + args = parse_args() + + llava_path = args.mllm_path + model_path = os.path.expanduser(llava_path) + model_base = None + model_name = get_model_name_from_path(model_path) + tokenizer, mllm, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, use_flash_attn=False, load_4bit=False) + + choice = ["A", "B"] + rand_choice = [0, 1] + new_json = [] + device = "cuda" + + vision_model = CLIPModel.from_pretrained(args.vit_path).to(device).half() + processor = CLIPImageProcessor.from_pretrained(args.vit_path) + + blip_processor = BlipProcessor.from_pretrained(args.blip_path) + blip_model = BlipForImageTextRetrieval.from_pretrained(args.blip_path, torch_dtype=torch.float16).to(device).half() + + fastSAM_model = FastSAM(args.fastsam_path) + + image_dataset = InferenceDataset_for_FastSAM(args.json_path) + print(args.json_path) + dataloader_fastsam = DataLoader(image_dataset, batch_size=16, drop_last=False) + + pipeline = AutoPipelineForInpainting.from_pretrained(args.sd_model_path, torch_dtype=torch.float16).to("cuda") + + with torch.no_grad(): + for image_path_list in tqdm.tqdm(dataloader_fastsam): + # print(len(image_path_list)) + image_list1 = [] + image_list2 = [] + for temp_idx_list in range(len(image_path_list["path"])): + image_array1 = cv2.cvtColor(cv2.imread(image_path_list["path"][temp_idx_list].replace("./", "../new_edit_data/") + "_0.jpg"), cv2.COLOR_BGR2RGB) + image_array1 = cv2.resize(image_array1, (512, 512)) + image_list1.append(image_array1) + + image_array2 = cv2.cvtColor(cv2.imread(image_path_list["path"][temp_idx_list].replace("./", "../new_edit_data/") + "_1.jpg"), cv2.COLOR_BGR2RGB) + image_array2 = cv2.resize(image_array2,(512, 512)) + image_list2.append(image_array2) + + masks1 = fastSAM_model(image_list1, retina_masks=True, imgsz=1024, conf=0.1, iou=0.5, verbose=False) + masks2 = fastSAM_model(image_list2, retina_masks=True, imgsz=1024, conf=0.1, iou=0.5, verbose=False) + + for temp_idx_mask in range(len(image_path_list["path"])): + + # print(image_path_list["input"][temp_idx_mask]) + # print(image_path_list["output"][temp_idx_mask]) + + noun1, noun2 = compare_text_index(image_path_list["input"][temp_idx_mask], image_path_list["output"][temp_idx_mask]) + if noun1 == "" or noun2 == "": + continue + + # print(image_path_list["input"][temp_idx_mask]) + # print(image_path_list["output"][temp_idx_mask]) + # print(noun1) + # print(noun2) + + temp_mask1 = masks1[temp_idx_mask] + temp_mask2 = masks2[temp_idx_mask] + if len(temp_mask1.boxes.xyxy) == 0: + continue + + image_array1 = image_list1[temp_idx_mask] + image_array2 = image_list2[temp_idx_mask] + + image_targets = [] + image_targets_pos = [] + diff_targets_1 = [] + diff_targets_2 = [] + mask_list = [] + + with torch.no_grad(): + for temp_target, temp_ori_mask in zip(temp_mask1.boxes.xyxy, temp_mask1.masks): + # crop_img = image_array1[int(temp_target['bbox'][1]):int(temp_target['bbox'][1])+int(temp_target['bbox'][3]),int(temp_target['bbox'][0]):int(temp_target['bbox'][0])+int(temp_target['bbox'][2]),:] + crop_img = image_array1[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img) + image_targets.append(img) + # image1_targets_pos.append(temp_target['bbox']) + image_targets_pos.append(temp_target.cpu()) + + cv2_img = np.where(temp_ori_mask.data.cpu().numpy().transpose((1,0,2)).transpose((0,2,1)) > 0.5, 255, 0).astype(np.uint8) + cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_GRAY2RGB) + pil_img = Image.fromarray(cv2_img) + mask_list.append(pil_img) + + crop_img_same_pos = image_array2[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img_same_pos) + image_targets.append(img) + + num_image1_targets = len(image_targets) + + # cv2.imwrite(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img0.jpg"), image_array1) + # cv2.imwrite(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img1.jpg"), image_array2) + + for temp_target, temp_ori_mask in zip(temp_mask2.boxes.xyxy, temp_mask2.masks): + # crop_img = image_array1[int(temp_target['bbox'][1]):int(temp_target['bbox'][1])+int(temp_target['bbox'][3]),int(temp_target['bbox'][0]):int(temp_target['bbox'][0])+int(temp_target['bbox'][2]),:] + crop_img = image_array2[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img) + image_targets.append(img) + # image1_targets_pos.append(temp_target['bbox']) + image_targets_pos.append(temp_target.cpu()) + + cv2_img = np.where(temp_ori_mask.data.cpu().numpy().transpose((1,0,2)).transpose((0,2,1)) > 0.5, 255, 0).astype(np.uint8) + cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_GRAY2RGB) + pil_img = Image.fromarray(cv2_img) + mask_list.append(pil_img) + + crop_img_same_pos = image_array1[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img_same_pos) + image_targets.append(img) + + # print(len(image_targets)) + + if len(image_targets) == 0: + continue + try: + image_targets_clip = processor(image_targets, return_tensors="pt")['pixel_values'].half() + except: + continue + image_dataset = InferenceDataset_for_clip(image_targets_clip) + dataloader_clip = DataLoader(image_dataset, batch_size=256, drop_last=False) + + image_feature = None + for batch in dataloader_clip: + temp_image_feature = vision_model.get_image_features(batch.to(vision_model.device)) + if image_feature == None: + image_feature = temp_image_feature.to(torch.float32) + else: + image_feature = torch.cat((image_feature, temp_image_feature.to(torch.float32)), dim = 0) + + + try: + image_targets_blip = blip_processor(image_targets, [noun1, noun2], return_tensors="pt", padding=True).to(device, torch.float16) + except: + continue + + input_ids = image_targets_blip['input_ids'] + attention_mask = image_targets_blip['attention_mask'] + image_dataset = InferenceDataset_for_blip(image_targets_blip['pixel_values']) + dataloader_blip = DataLoader(image_dataset, batch_size=256, drop_last=False) + blip_itm_score = None + for pixel_values in dataloader_blip: + cosine_score = blip_model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, use_itm_head=False).itm_score + if blip_itm_score == None: + blip_itm_score = cosine_score.to(torch.float32) + else: + blip_itm_score = torch.cat((blip_itm_score, cosine_score.to(torch.float32)), dim = 0) + + # print(blip_itm_score) + + for temp_idx_cos in range(0, image_feature.shape[0], 2): + + # thresh = 0.3 + # if blip_itm_score[temp_idx_cos][0] < thresh and blip_itm_score[temp_idx_cos + 1][1] < thresh: + # continue + + + + cos = torch.cosine_similarity(image_feature[temp_idx_cos], image_feature[temp_idx_cos + 1], dim=0) + if (cos<0.9): + temp_diff_target = [] + temp_diff_target.extend(image_targets_pos[int(temp_idx_cos/2)]) + temp_diff_target.append(cos.cpu()) + + if temp_idx_cos < num_image1_targets: + temp_diff_target.append(1) + else: + temp_diff_target.append(2) + + temp_diff_target.append(int(temp_idx_cos/2)) + + if temp_idx_cos < num_image1_targets: + diff_targets_1.append(temp_diff_target) + else: + diff_targets_2.append(temp_diff_target) + + + # print(len(diff_targets)) + + if len(diff_targets_1) + len(diff_targets_2) == 0: + continue + + if len(diff_targets_1) > 0 : + filtered_targets_1 = iou_filter(np.array(diff_targets_1), 0.5) + + for temp_idx, temp_filtered_target in enumerate(filtered_targets_1): + + if temp_idx == 3: + break + + bbox_x1 = temp_filtered_target[0] + bbox_y1 = temp_filtered_target[1] + bbox_x2 = temp_filtered_target[2] + bbox_y2 = temp_filtered_target[3] + + which_img = temp_filtered_target[5] + mask = mask_list[int(temp_filtered_target[6])] + + if which_img == 1: + + # image 1 + prompt = "background, nothing, 8k" + new_image = pipeline(prompt=prompt, image=Image.fromarray(image_array1), mask_image=mask, strength=0.85, guidance_scale=0, num_inference_steps=4).images[0] + temp_image_array1 = Image.fromarray(image_array1) + + # mllm captioning + prompt_bbox = f"Please provide a clear description for this region: [{str(bbox_x1)}, {str(bbox_y1)}, {str(bbox_x2)}, {str(bbox_y2)}]." + prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt_bbox + conv = conv_templates["vicuna_v1"].copy() + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + input_ids = input_ids.to(device=mllm.device, non_blocking=True) + image_tensor = process_images([temp_image_array1], image_processor, mllm.config)[0] + + temperature = 0 + with torch.inference_mode(): + output_ids = mllm.generate( + input_ids.unsqueeze(0), + images=image_tensor.to(dtype=torch.float16, device=mllm.device, non_blocking=True).unsqueeze(0), + image_sizes=[temp_image_array1.size], + do_sample=True if temperature > 0 else False, + temperature=temperature, + top_p=None, + num_beams=1, + max_new_tokens=64, + use_cache=True) # + caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + + # print(caption) + + # caption quality filter + crop_pil_img1 = temp_image_array1.crop((bbox_x1, bbox_y1, bbox_x2, bbox_y2)) + crop_pil_img2 = new_image.crop((bbox_x1, bbox_y1, bbox_x2, bbox_y2)) + crop_inputs_list = blip_processor([crop_pil_img1, crop_pil_img2], [caption], return_tensors="pt", padding=True).to("cuda", torch.float16) + cosine_score = blip_model(**crop_inputs_list, use_itm_head=False).itm_score + if cosine_score[0][0] < 0.35 or cosine_score[1][0] > 0.2: + continue + + + now_choice = random.choice(rand_choice) + draw = ImageDraw.ImageDraw(temp_image_array1) + draw.rectangle(((bbox_x1-15, bbox_y1-15),(bbox_x2+15, bbox_y2+15)), fill=None, outline='red', width=3) + temp_image_array1.save(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img0_" + str(temp_idx) + "_" + args.split_name + "_" + str(1-now_choice) + ".jpg")) + # cv2.imwrite(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img0_" + args.split_name + "_" + str(1-now_choice) + ".jpg"), image_array1) + + draw = ImageDraw.ImageDraw(new_image) + draw.rectangle(((bbox_x1-15, bbox_y1-15),(bbox_x2+15, bbox_y2+15)), fill=None, outline='red', width=3) + new_image.save(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img0_" + str(temp_idx) + "_" + args.split_name + "_" + str(now_choice) + ".jpg")) + + temp_json = {} + temp_json["bbox"] = [int(bbox_x1), int(bbox_y1), int(bbox_x2), int(bbox_y2)] + temp_json["conversations"] = [] + + temp_conversation = {} + temp_conversation["from"] = "human" + temp_conversation["value"] = f"Analyse the the left image and the right image (separated by the black vertical bar). Which image has the object related to \"{caption}\" within the red bounding box?\nA. the left image\nB. the right image\nAnswer with the option's letter from the given choices directly." + temp_json["conversations"].append(temp_conversation) + + temp_conversation = {} + temp_conversation["from"] = "gpt" + temp_conversation["value"] = choice[now_choice] + temp_json["conversations"].append(temp_conversation) + + temp_json["path"] = os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img0_" + str(temp_idx) + "_" + args.split_name) + + new_json.append(temp_json) + + if len(diff_targets_2) > 0: + filtered_targets_2 = iou_filter(np.array(diff_targets_2), 0.5) + + for temp_idx, temp_filtered_target in enumerate(filtered_targets_2): + + if temp_idx == 3: + break + + bbox_x1 = temp_filtered_target[0] + bbox_y1 = temp_filtered_target[1] + bbox_x2 = temp_filtered_target[2] + bbox_y2 = temp_filtered_target[3] + + which_img = temp_filtered_target[5] + mask = mask_list[int(temp_filtered_target[6])] + + if which_img == 2: + + # image 2 + prompt = "background, nothing, 8k" + # print(noun2) + # print(prompt) + new_image = pipeline(prompt=prompt, image=Image.fromarray(image_array2), mask_image=mask, strength=0.85, guidance_scale=0, num_inference_steps=4).images[0] + temp_image_array2 = Image.fromarray(image_array2) + + # mllm captioning + prompt_bbox = f"Please provide a clear description for this region: [{str(bbox_x1)}, {str(bbox_y1)}, {str(bbox_x2)}, {str(bbox_y2)}]." + prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt_bbox + conv = conv_templates["vicuna_v1"].copy() + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + input_ids = input_ids.to(device=mllm.device, non_blocking=True) + image_tensor = process_images([temp_image_array2], image_processor, mllm.config)[0] + + temperature = 0 + with torch.inference_mode(): + output_ids = mllm.generate( + input_ids.unsqueeze(0), + images=image_tensor.to(dtype=torch.float16, device=mllm.device, non_blocking=True).unsqueeze(0), + image_sizes=[temp_image_array2.size], + do_sample=True if temperature > 0 else False, + temperature=temperature, + top_p=None, + num_beams=1, + max_new_tokens=64, + use_cache=True) # + caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + + # print(caption) + + # caption quality filter + crop_pil_img1 = temp_image_array2.crop((bbox_x1, bbox_y1, bbox_x2, bbox_y2)) + crop_pil_img2 = new_image.crop((bbox_x1, bbox_y1, bbox_x2, bbox_y2)) + crop_inputs_list = blip_processor([crop_pil_img1, crop_pil_img2], [caption], return_tensors="pt", padding=True).to("cuda", torch.float16) + cosine_score = blip_model(**crop_inputs_list, use_itm_head=False).itm_score + # print(cosine_score) + if cosine_score[0][0] < 0.35 or cosine_score[1][0] > 0.2: + continue + + + now_choice = random.choice(rand_choice) + draw = ImageDraw.ImageDraw(temp_image_array2) + draw.rectangle(((bbox_x1-15, bbox_y1-15),(bbox_x2+15, bbox_y2+15)), fill=None, outline='red', width=3) + temp_image_array2.save(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img1_" + str(temp_idx) + "_" + args.split_name + "_" + str(1-now_choice) + ".jpg")) + # cv2.imwrite(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img1_" + args.split_name + "_" + str(1-now_choice) + ".jpg"), image_array2) + + draw = ImageDraw.ImageDraw(new_image) + draw.rectangle(((bbox_x1-15, bbox_y1-15),(bbox_x2+15, bbox_y2+15)), fill=None, outline='red', width=3) + new_image.save(os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img1_" + str(temp_idx) + "_" + args.split_name + "_" + str(now_choice) + ".jpg")) + + temp_json = {} + temp_json["bbox"] = [int(bbox_x1), int(bbox_y1), int(bbox_x2), int(bbox_y2)] + temp_json["conversations"] = [] + + temp_conversation = {} + temp_conversation["from"] = "human" + temp_conversation["value"] = f"Analyse the the left image and the right image (separated by the black vertical bar). Which image has the object related to \"{caption}\" within the red bounding box?\nA. the left image\nB. the right image\nAnswer with the option's letter from the given choices directly." + temp_json["conversations"].append(temp_conversation) + + temp_conversation = {} + temp_conversation["from"] = "gpt" + temp_conversation["value"] = choice[now_choice] + temp_json["conversations"].append(temp_conversation) + + temp_json["path"] = os.path.join(args.out_path, image_path_list["path"][temp_idx_mask].split("/")[-1] + "_img1_" + str(temp_idx) + "_" + args.split_name) + + new_json.append(temp_json) + # break + # break + + print(len(new_json)) + # print(new_json) + with open(args.output_file, "w") as f: + f.write(json.dumps(new_json)) diff --git a/Img-Diff-codes/object_removal/run_generate_inpaint.sh b/Img-Diff-codes/object_removal/run_generate_inpaint.sh new file mode 100644 index 000000000..5799571d2 --- /dev/null +++ b/Img-Diff-codes/object_removal/run_generate_inpaint.sh @@ -0,0 +1,11 @@ +CUDA_VISIBLE_DEVICES="0" python generate_inpaint.py \ + --out_path ./inpaint_all_mscoco \ + --vit_path clip-vit-base-patch32 \ + --blip_path blip-itm-large-coco \ + --fastsam_path FastSAM-x.pt \ + --sd_model_path sdxl-turbo \ + --mllm_path llava-v1.6-vicuna-7b \ + --json_path filtered_file_new_edit_all_0.json \ + --output_file inpaint_mllm_all_mscoco_0.json \ + --split_name 0 + diff --git a/Img-Diff-codes/object_replacement/cos_count.py b/Img-Diff-codes/object_replacement/cos_count.py new file mode 100644 index 000000000..a221e2e16 --- /dev/null +++ b/Img-Diff-codes/object_replacement/cos_count.py @@ -0,0 +1,63 @@ +import json +import tqdm + +new_json = [] +for split_name in range(4): + with open(f"./old_after_0625/filtered_file_new_caption_{str(split_name)}.txt", "r") as f: + data = f.readlines() + for temp_line in tqdm.tqdm(data): + temp_json = eval(temp_line) + # print(temp_json) + if temp_json["cos"] > 0.9 and temp_json["cos"] <0.98: + new_json.append(temp_json) + + # if temp_json["cos"] <= 0.9 or temp_json["cos"] >= 0.98: + # new_json.append(temp_json) + + # break + +# print(new_json) +print(len(new_json)) + + +# with open("filtered_file_new_edit_85_98.json", "w") as f: # filtered_file_new_edit_09_098.json +# f.write(json.dumps(new_json)) + + +data = new_json +new_json_0 = [] +new_json_1 = [] +new_json_2 = [] +new_json_3 = [] + +length = len(data) +piece_len = int(length/4) + +for idx, piece in tqdm.tqdm(enumerate(data)): + if idx < piece_len: + new_json_0.append(piece) + elif idx >= piece_len and idx < 2*piece_len: + new_json_1.append(piece) + elif idx >= 2 * piece_len and idx < 3*piece_len: + new_json_2.append(piece) + else: + new_json_3.append(piece) + +# print(length) +print(piece_len) +print(len(new_json_0)) +print(len(new_json_1)) +print(len(new_json_2)) +print(len(new_json_3)) + +with open("filtered_file_new_caption_9_98_0.json", "w") as f: + f.write(json.dumps(new_json_0)) + +with open("filtered_file_new_caption_9_98_1.json", "w") as f: + f.write(json.dumps(new_json_1)) + +with open("filtered_file_new_caption_9_98_2.json", "w") as f: + f.write(json.dumps(new_json_2)) + +with open("filtered_file_new_caption_9_98_3.json", "w") as f: + f.write(json.dumps(new_json_3)) \ No newline at end of file diff --git a/Img-Diff-codes/object_replacement/cos_filter.py b/Img-Diff-codes/object_replacement/cos_filter.py new file mode 100644 index 000000000..d792797b3 --- /dev/null +++ b/Img-Diff-codes/object_replacement/cos_filter.py @@ -0,0 +1,83 @@ +import os +from transformers import CLIPImageProcessor, CLIPModel +import json +import tqdm +from torch.utils.data import DataLoader, Dataset, SequentialSampler +import torch +from PIL import Image +import argparse + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--folder_path', type=str, default="./prompt-to-prompt-with-sdxl/output") + parser.add_argument('--json_path', type=str, default="./gen_0.json") + parser.add_argument('--split_name', type=str, default="0") + parser.add_argument('--clip_vit_path', type=str, default="clip-vit-base-patch32") + args=parser.parse_args() + + return args + + +class InferenceDataset(Dataset): + + def __init__(self, images_path, json_path, processor, split_name): + self.images_path = images_path + with open(json_path, "r") as f: + self.json_file = json.load(f) + self.idx_list = [] + for idx, i in tqdm.tqdm(enumerate(self.json_file)): + if os.path.exists(os.path.join(self.images_path, split_name) + "_" + str(idx) + "_0.jpg") and "---" not in i["output"] and "replaced" not in i["output"].lower(): + self.idx_list.append(idx) + print(len(self.idx_list)) + self.processor = processor + self.split_name = split_name + + def __len__(self) -> int: + return len(self.idx_list) + + @torch.no_grad() + def __getitem__(self, idx: int): + image_path1 = os.path.join(self.images_path, self.split_name) + "_" + str(self.idx_list[idx]) + "_0.jpg" + image1 = Image.open(image_path1).convert('RGB') + + image_path2 = os.path.join(self.images_path, self.split_name) + "_" + str(self.idx_list[idx]) + "_1.jpg" + image2 = Image.open(image_path2).convert('RGB') + + images_tensor = self.processor([image1, image2], return_tensors="pt")['pixel_values'] + + return images_tensor[0], images_tensor[1], os.path.join(self.images_path, self.split_name) + "_" + str(self.idx_list[idx]), self.json_file[self.idx_list[idx]] + + +if __name__ == "__main__": + args = parse_args() + + device = "cuda" + filtered_file = [] + images_path = args.folder_path + json_path = args.json_path + + vision_model = CLIPModel.from_pretrained(args.clip_vit_path).to(device).half() + processor = CLIPImageProcessor.from_pretrained(args.clip_vit_path) + + image_dataset = InferenceDataset(images_path, json_path, processor, args.split_name) + sampler = SequentialSampler(image_dataset) + dataloader = DataLoader(image_dataset, sampler=sampler, batch_size=2, drop_last=False) + + with torch.no_grad(): + for (image1_batch, image2_batch, images_path, json_piece) in tqdm.tqdm(dataloader): + image1_batch_feature = vision_model.get_image_features(image1_batch.to(vision_model.device)) + image2_batch_feature = vision_model.get_image_features(image2_batch.to(vision_model.device)) + + cos_list = torch.cosine_similarity(image1_batch_feature, image2_batch_feature, dim=1) + for temp_idx in range(image1_batch_feature.shape[0]): + cos = cos_list[temp_idx] + temp_json = {} + temp_json["path"] = images_path[temp_idx] + temp_json["cos"] = round(float(cos), 3) + temp_json["input"] = json_piece["input"][temp_idx] + temp_json["output"] = json_piece["output"][temp_idx] + with open(f"./filtered_file_new_caption_{args.split_name}.txt", "a") as txt_f: + txt_f.write(str(temp_json) + "\n") + + + \ No newline at end of file diff --git a/Img-Diff-codes/object_replacement/cos_filter.sh b/Img-Diff-codes/object_replacement/cos_filter.sh new file mode 100644 index 000000000..0385e127a --- /dev/null +++ b/Img-Diff-codes/object_replacement/cos_filter.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES="0" python cos_filter.py \ + --folder_path "./prompt-to-prompt-with-sdxl/output_new_caption" \ + --json_path "./new_caption_0621_0.json" \ + --split_name 0 \ No newline at end of file diff --git a/Img-Diff-codes/object_replacement/generate_bbox.py b/Img-Diff-codes/object_replacement/generate_bbox.py new file mode 100644 index 000000000..a1851976d --- /dev/null +++ b/Img-Diff-codes/object_replacement/generate_bbox.py @@ -0,0 +1,366 @@ +from ultralytics import FastSAM, YOLO +from transformers import CLIPImageProcessor, CLIPModel +import json +import numpy as np +import tqdm +import cv2 +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +import difflib +from transformers import BlipProcessor, BlipForImageTextRetrieval +import re +from nltk.corpus import wordnet as wn +from nltk.corpus import words +from nltk import pos_tag +from nltk.stem import WordNetLemmatizer +import argprase + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--vit_path', type=str, default="clip-vit-base-patch32") + parser.add_argument('--blip_path', type=str, default="blip-itm-large-coco") + parser.add_argument('--fastsam_path', type=str, default="FastSAM-x.pt") + parser.add_argument('--json_path', type=str, default="filtered_file_new_edit_09_098_3.json") + parser.add_argument('--output_file', type=str, default="bbox_file_3.json") + + args=parser.parse_args() + + return args + +def is_noun(word): + # print(word) + pos_tagged = pos_tag([word]) + pos = pos_tagged[0][1] + + if not pos in ['NN', 'NNS', 'NNP', 'NNPS']: + return False + + return True + + +def is_adj(word): + # print(word) + pos_tagged = pos_tag([word]) + pos = pos_tagged[0][1] + + if not pos in ["JJ", "JJR", "JJS"]: + return False + + return True + + + +def iou_filter(samples, iou_thresh): + x1 = samples[:, 0] + y1 = samples[:, 1] + x2 = samples[:, 2] + y2 = samples[:, 3] + scores = samples[:, 4] + + areas = (y2 - y1 + 1) * (x2 - x1 + 1) + keep_boxes = [] + index = scores.argsort() # Ascending + + while len(index) > 0: + i = index[0] + keep_boxes.append(i) + + x1_overlap = np.maximum(x1[i], x1[index[1:]]) + y1_overlap = np.maximum(y1[i], y1[index[1:]]) + x2_overlap = np.minimum(x2[i], x2[index[1:]]) + y2_overlap = np.minimum(y2[i], y2[index[1:]]) + + w = np.maximum(0, x2_overlap - x1_overlap + 1) + h = np.maximum(0, y2_overlap - y1_overlap + 1) + overlap_area = w * h + + ious = overlap_area / (areas[i] + areas[index[1:]] - overlap_area) + + idx = np.where(ious <= iou_thresh)[0] + index = index[idx + 1] # update + + return samples[keep_boxes] + + +class InferenceDataset_for_FastSAM(Dataset): + + def __init__(self, json_path): + with open(json_path, "r") as f: + self.image_path = json.load(f) + + def __len__(self) -> int: + return len(self.image_path) + + @torch.no_grad() + def __getitem__(self, idx: int): + # image_array1 = cv2.cvtColor(cv2.imread(self.image_path[idx] + "_0.jpg"), cv2.COLOR_BGR2RGB) + # image_array2 = cv2.cvtColor(cv2.imread(self.image_path[idx] + "_1.jpg"), cv2.COLOR_BGR2RGB) + # image_array2 = cv2.resize(image_array2,(image_array1.shape[1],image_array1.shape[0])) + + return self.image_path[idx] + + + +class InferenceDataset_for_clip(Dataset): + + def __init__(self, image_list): + self.image_list = image_list + + def __len__(self) -> int: + return len(self.image_list) + + @torch.no_grad() + def __getitem__(self, idx: int): + + return self.image_list[idx] + + +class InferenceDataset_for_blip(Dataset): + + def __init__(self, pixel_values): + self.pixel_values = pixel_values + + def __len__(self) -> int: + return len(self.pixel_values) + + @torch.no_grad() + def __getitem__(self, idx: int): + + return self.pixel_values[idx] + + +def compare_text_index(text1, text2): + + text1_split = [] + text2_split = [] + + lemmatizer=WordNetLemmatizer() + + d = difflib.Differ() + diff = d.compare(re.sub(r'[^\w\s]', '', text1.lower().replace(" ", "\n")).splitlines(), re.sub(r'[^\w\s]', '', text2.lower().replace(" ", "\n")).splitlines()) + + + for line in diff: + if line.startswith('+'): + text1_split.append(lemmatizer.lemmatize(line.replace("+ ", ""))) + elif line.startswith('-'): + text2_split.append(lemmatizer.lemmatize(line.replace("- ", ""))) + + text1 = [] + text2 = [] + + for temp_idx, temp_word1 in enumerate(text1_split): + if temp_word1 not in text2_split: + if is_noun(temp_word1): + text1.append(temp_word1) + + for temp_idx, temp_word2 in enumerate(text2_split): + if temp_word2 not in text1_split: + if is_noun(temp_word2): + text2.append(temp_word2) + + return text1, text2 + + + + + +if __name__ == "__main__": + args = parse_args() + + new_json = [] + device = "cuda" + + vision_model = CLIPModel.from_pretrained(args.vit_path).to(device).half() + processor = CLIPImageProcessor.from_pretrained(args.vit_path) + + blip_processor = BlipProcessor.from_pretrained(args.blip_path) + blip_model = BlipForImageTextRetrieval.from_pretrained(args.blip_path, torch_dtype=torch.float16).to(device).half() + + fastSAM_model = FastSAM(args.fastsam_path) + + + image_dataset = InferenceDataset_for_FastSAM(args.json_path) + print(args.json_path) + dataloader_fastsam = DataLoader(image_dataset, batch_size=16, drop_last=False) + + with torch.no_grad(): + for image_path_list in tqdm.tqdm(dataloader_fastsam): + # print(len(image_path_list)) + image_list1 = [] + image_list2 = [] + for temp_idx_list in range(len(image_path_list["path"])): + image_array1 = cv2.cvtColor(cv2.imread(image_path_list["path"][temp_idx_list] + "_0.jpg"), cv2.COLOR_BGR2RGB) + image_list1.append(image_array1) + + image_array2 = cv2.cvtColor(cv2.imread(image_path_list["path"][temp_idx_list] + "_1.jpg"), cv2.COLOR_BGR2RGB) + image_array2 = cv2.resize(image_array2,(image_array1.shape[1],image_array1.shape[0])) + image_list2.append(image_array2) + + masks1 = fastSAM_model(image_list1, retina_masks=True, imgsz=1024, conf=0.05, iou=0.5, verbose=False) + masks2 = fastSAM_model(image_list2, retina_masks=True, imgsz=1024, conf=0.05, iou=0.5, verbose=False) + + for temp_idx_mask in range(len(image_path_list["path"])): + + # print(image_path_list["input"][temp_idx_mask]) + # print(image_path_list["output"][temp_idx_mask]) + + noun1, noun2 = compare_text_index(image_path_list["input"][temp_idx_mask], image_path_list["output"][temp_idx_mask]) + if noun1 == [] and noun2 == []: + continue + + # print(noun1) + # print(noun2) + + temp_mask1 = masks1[temp_idx_mask] + temp_mask2 = masks2[temp_idx_mask] + if len(temp_mask1.boxes.xyxy) + len(temp_mask2.boxes.xyxy) == 0: + continue + + image_array1 = image_list1[temp_idx_mask] + image_array2 = image_list2[temp_idx_mask] + + image_targets = [] + image_targets_pos = [] + diff_targets = [] + + with torch.no_grad(): + for temp_target in temp_mask1.boxes.xyxy: + # crop_img = image_array1[int(temp_target['bbox'][1]):int(temp_target['bbox'][1])+int(temp_target['bbox'][3]),int(temp_target['bbox'][0]):int(temp_target['bbox'][0])+int(temp_target['bbox'][2]),:] + crop_img = image_array1[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img) + image_targets.append(img) + # image1_targets_pos.append(temp_target['bbox']) + image_targets_pos.append(temp_target.cpu()) + + crop_img_same_pos = image_array2[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img_same_pos) + image_targets.append(img) + + num_image1_targets = len(image_targets) + + + + + for temp_target in temp_mask2.boxes.xyxy: + # crop_img = image_array1[int(temp_target['bbox'][1]):int(temp_target['bbox'][1])+int(temp_target['bbox'][3]),int(temp_target['bbox'][0]):int(temp_target['bbox'][0])+int(temp_target['bbox'][2]),:] + crop_img = image_array2[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img) + image_targets.append(img) + # image1_targets_pos.append(temp_target['bbox']) + image_targets_pos.append(temp_target.cpu()) + + crop_img_same_pos = image_array1[int(temp_target[1]):int(temp_target[3]),int(temp_target[0]):int(temp_target[2]),:] + img = Image.fromarray(crop_img_same_pos) + image_targets.append(img) + + # print(len(image_targets)) + + if len(image_targets) == 0: + continue + try: + image_targets_clip = processor(image_targets, return_tensors="pt")['pixel_values'].half() + except: + continue + image_dataset = InferenceDataset_for_clip(image_targets_clip) + dataloader_clip = DataLoader(image_dataset, batch_size=256, drop_last=False) + + image_feature = None + for batch in dataloader_clip: + temp_image_feature = vision_model.get_image_features(batch.to(vision_model.device)) + if image_feature == None: + image_feature = temp_image_feature.to(torch.float32) + else: + image_feature = torch.cat((image_feature, temp_image_feature.to(torch.float32)), dim = 0) + + + temp_noun = [] + temp_noun.extend(noun1) + temp_noun.extend(noun2) + try: + image_targets_blip = blip_processor(image_targets, temp_noun, return_tensors="pt", padding=True).to(device, torch.float16) + except: + continue + + input_ids = image_targets_blip['input_ids'] + attention_mask = image_targets_blip['attention_mask'] + image_dataset = InferenceDataset_for_blip(image_targets_blip['pixel_values']) + dataloader_blip = DataLoader(image_dataset, batch_size=256, drop_last=False) + blip_itm_score = None + for pixel_values in dataloader_blip: + cosine_score = blip_model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, use_itm_head=False).itm_score + if blip_itm_score == None: + blip_itm_score = cosine_score.to(torch.float32) + else: + blip_itm_score = torch.cat((blip_itm_score, cosine_score.to(torch.float32)), dim = 0) + + # print(blip_itm_score) + + for temp_idx_cos in range(0, image_feature.shape[0], 2): + + thresh = 0.35 # 0.35 + # if blip_itm_score[temp_idx_cos][0] < thresh and blip_itm_score[temp_idx_cos + 1][1] < thresh: # and -> or + # continue + + not_match = True # effective object + + for temp_count in range(len(noun1)): + if blip_itm_score[temp_idx_cos][temp_count] > thresh: + not_match = False + break + + if not_match: + for temp_count in range(len(noun2)): + if blip_itm_score[temp_idx_cos + 1][len(noun1) + temp_count] > thresh: + not_match = False + break + + if not_match: + continue + + + cos = torch.cosine_similarity(image_feature[temp_idx_cos], image_feature[temp_idx_cos + 1], dim=0) + if (cos<0.85): # 0.95 0.85 + temp_diff_target = [] + temp_diff_target.extend(image_targets_pos[int(temp_idx_cos/2)]) + temp_diff_target.append(cos.cpu()) + + if temp_idx_cos < num_image1_targets: + temp_diff_target.append(1) + else: + temp_diff_target.append(2) + diff_targets.append(temp_diff_target) + + # print(len(diff_targets)) + + if len(diff_targets) == 0: + continue + + + filtered_targets = iou_filter(np.array(diff_targets), 0.5) + + temp_new_json = {} + temp_new_json["path"] = image_path_list["path"][temp_idx_mask] + temp_filtered_bbox = [] + for temp_idx_bbox, temp_filtered_targets in enumerate(filtered_targets): + if temp_idx_bbox == 10: + break + temp_bbox_num = [] + for num in temp_filtered_targets[0:4]: + temp_bbox_num.append(round(float(num), 1)) + temp_filtered_bbox.append(temp_bbox_num) + temp_new_json["bbox"] = temp_filtered_bbox + + # print(len(temp_filtered_bbox)) + new_json.append(temp_new_json) + + if len(new_json) % 1000 == 0: + print(len(new_json)) + + # break + + print(len(new_json)) + with open(args.output_file, "w") as f: + f.write(json.dumps(new_json)) diff --git a/Img-Diff-codes/object_replacement/generate_bbox.sh b/Img-Diff-codes/object_replacement/generate_bbox.sh new file mode 100644 index 000000000..9bf0749b3 --- /dev/null +++ b/Img-Diff-codes/object_replacement/generate_bbox.sh @@ -0,0 +1,7 @@ +CUDA_VISIBLE_DEVICES="0" python generate_bbox.py \ + --vit_path "clip-vit-base-patch32" \ + --blip_path "blip-itm-large-coco" \ + --fastsam_path "FastSAM-x.pt" \ + --json_path "filtered_file_new_edit_09_098_3.json" \ + --output_file "bbox_file_3.json" + diff --git a/Img-Diff-codes/object_replacement/generate_final_data_new_edit.py b/Img-Diff-codes/object_replacement/generate_final_data_new_edit.py new file mode 100644 index 000000000..4df54b9cf --- /dev/null +++ b/Img-Diff-codes/object_replacement/generate_final_data_new_edit.py @@ -0,0 +1,372 @@ +import sys +sys.path.append("./LLaVA/") + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path +from torch.utils.data import Dataset, DataLoader +from PIL import Image, ImageDraw +import json +import argparse +import tqdm +import torch +import os +from transformers import CLIPImageProcessor, CLIPModel, AutoTokenizer +from transformers import BlipProcessor, BlipForImageTextRetrieval +import random +from copy import deepcopy + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--inp2p_bbox_json_path', type=str, default="./filtered_inp2p_bbox.json") + parser.add_argument('--llava_path', type=str, default="./llava-v1.6-vicuna-7b") + parser.add_argument('--clip_path', type=str, default="./clip-vit-base-patch32") + parser.add_argument('--blip_path', type=str, default="./blip-itm-large-coco") + parser.add_argument('--img_dir', type=str, default="./instructpix2pix") + parser.add_argument('--output_img_dir', type=str, default="./new_edit_img") + parser.add_argument('--device', type=str, default="cuda") + parser.add_argument('--qa_turns', type=int, default=6) + parser.add_argument("--output_file", type=str, default=".") + args=parser.parse_args() + + return args + +def get_sub_list(ori_list, indice_list): + new_list = [] + for i in indice_list: + new_list.append(ori_list[i]) + return new_list + + +class InferenceDataset(Dataset): + + def __init__(self, args): + with open(args.inp2p_bbox_json_path, "r") as f: + self.json_file = json.load(f) + self.args = args + + def __len__(self) -> int: + return len(self.json_file) + + @torch.no_grad() + def __getitem__(self, idx: int): + + + return self.json_file[idx]["path"].replace("./prompt-to-prompt-with-sdxl/output", "/"), self.json_file[idx]["bbox"] + + +# Adopted from https://github.com/mapluisch/LLaVA-CLI-with-multiple-images/blob/main/llava-multi-images.py +def concatenate_images_horizontal(images, bar_width): + # calc total width of imgs + dist between them + total_width = sum(img.width for img in images) + bar_width * (len(images) - 1) + # calc max height from imgs + height = max(img.height for img in images) + + # create new img with calculated dimensions, black bg + new_img = Image.new('RGB', (total_width, height), (0, 0, 0)) + + # init var to track current width pos + current_width = 0 + for img in images: + # paste img in new_img at current width + new_img.paste(img, (current_width, 0)) + # update current width for next img + current_width += img.width + bar_width + + return new_img + +def iou_filter(now_bbox, bbox_list, thresh): + for temp in bbox_list: + x1_overlap = max(now_bbox[0], temp[0]) + y1_overlap = max(now_bbox[1], temp[1]) + x2_overlap = min(now_bbox[2], temp[2]) + y2_overlap = min(now_bbox[3], temp[3]) + + w = max(0, x2_overlap - x1_overlap) + h = max(0, y2_overlap - y1_overlap) + overlap_area = w * h + + iou = overlap_area / ((now_bbox[2] - now_bbox[0])*(now_bbox[3]-now_bbox[1]) + (temp[2]-temp[0])*(temp[3]-temp[1]) - overlap_area) + + if iou > thresh: + return True + + return False + + + + +if __name__ == "__main__": + + args = parse_args() + device = args.device + + + blip_processor = BlipProcessor.from_pretrained(args.blip_path) + blip_model = BlipForImageTextRetrieval.from_pretrained(args.blip_path, torch_dtype=torch.float16).to(device).half() + + + clip_model = CLIPModel.from_pretrained(args.clip_path).to(device).half() + clip_processor = CLIPImageProcessor.from_pretrained(args.clip_path) + clip_tokenizer = AutoTokenizer.from_pretrained(args.clip_path) + + llava_path = args.llava_path + model_path = os.path.expanduser(llava_path) + model_base = None + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, use_flash_attn=True, load_4bit=False) + # model = model.to(device) + + batch_size = 1 + dataset = InferenceDataset(args) + dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False) + + + + count = 0 + new_json = [] + + for image_path, filtered_targets in tqdm.tqdm(dataloader): + + image_path = args.img_dir + image_path[0][1:] + + image1 = Image.open(image_path + "_0.jpg").convert('RGB') + image2 = Image.open(image_path + "_1.jpg").convert('RGB') + image2 = image2.resize((image1.size[0], image1.size[1])) + image_tensor1 = process_images([image1], image_processor, model.config)[0] + image_tensor2 = process_images([image2], image_processor, model.config)[0] + + image_tensor_list = [] + prompt_list = [] + concat_image_tensor_list = [] + bbox_list = [] + crop_bbox_list = [] + red_bbox_img_list1 = [] + red_bbox_img_list2 = [] + + + if len(filtered_targets) > 0 : + + prompt_bbox = "Please provide a clear description for this region: " # + for temp_idx, temp_target in enumerate(filtered_targets): + if (temp_target[2] - temp_target[0]) * (temp_target[3] - temp_target[1]) < (image1.size[0] * image1.size[1]) / 400: + continue + + if len(concat_image_tensor_list) == args.qa_turns: + break + + temp_image1 = image1.copy() + temp_image2 = image2.copy() + draw1 = ImageDraw.ImageDraw(temp_image1) + draw2 = ImageDraw.ImageDraw(temp_image2) + + extend_width = 5 + if temp_target[0] - extend_width >= 0: + extend_x1 = temp_target[0] - extend_width + else: + extend_x1 = 0 + + if temp_target[1] - extend_width >= 0: + extend_y1 = temp_target[1] - extend_width + else: + extend_y1 = 0 + + if temp_target[2] + extend_width <= image1.size[0]: + extend_x2 = temp_target[2] + extend_width + else: + extend_x2 = image1.size[0] + + if temp_target[3] +extend_width <= image1.size[1]: + extend_y2 = temp_target[3] +extend_width + else: + extend_y2 = image1.size[1] + + crop_bbox_list.append((int(extend_x1), int(extend_y1), int(extend_x2), int(extend_y2))) + + draw1.rectangle(((extend_x1, extend_y1),(extend_x2, extend_y2)), fill=None, outline='red', width=3) + draw2.rectangle(((extend_x1, extend_y1),(extend_x2, extend_y2)), fill=None, outline='red', width=3) + red_bbox_img_list1.append(temp_image1) + red_bbox_img_list2.append(temp_image2) + concat_image = concatenate_images_horizontal([temp_image1, temp_image2], 20) + concat_image.save("./label_img/" + str(temp_idx) + "concat_image_pil.jpg") + concat_image_tensor = process_images([concat_image], image_processor, model.config)[0] + + + image_tensor_list.append(image_tensor1) # img1 + image_tensor_list.append(image_tensor2) # img2 + concat_image_tensor_list.append(concat_image_tensor) + + temp_bbox_x1 = str(round(float(temp_target[0] / image1.size[0]), 2)) + temp_bbox_y1 = str(round(float(temp_target[1] / image1.size[1]), 2)) + temp_bbox_x2 = str(round(float(temp_target[2] / image1.size[0]), 2)) + temp_bbox_y2 = str(round(float(temp_target[3] / image1.size[1]), 2)) + + + + while(len(temp_bbox_x1) < 4): + temp_bbox_x1 = temp_bbox_x1 + "0" + while(len(temp_bbox_y1) < 4): + temp_bbox_y1 = temp_bbox_y1 + "0" + while(len(temp_bbox_x2) < 4): + temp_bbox_x2 = temp_bbox_x2 + "0" + while(len(temp_bbox_y2) < 4): + temp_bbox_y2 = temp_bbox_y2 + "0" + + str_bbox = "[" + temp_bbox_x1 + ", " + temp_bbox_y1 + ", "+ temp_bbox_x2 + ", " + temp_bbox_y2 +"]" + prompt = prompt_bbox + str_bbox + "." + bbox_list.append(str_bbox) + # prompt = "Analyse the the left image and the right image (separated by the black vertical bar). What differences are present within the red-bordered areas of the two images? The box may potentially be empty. Answer these questions in a concise sentence." + prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt + conv = conv_templates["vicuna_v1"].copy() + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + input_ids = input_ids.to(device=model.device, non_blocking=True) + + prompt_list.append(input_ids) # img1 + prompt_list.append(input_ids) # img2 + + + + if len(prompt_list) == 0: + continue + + input_ids = torch.stack(prompt_list).to(model.device) + image_tensor = torch.stack(image_tensor_list).to(dtype=torch.float16, device=model.device, non_blocking=True) + # concat_image_tensor = torch.stack(concat_image_tensor_list).to(dtype=torch.float16, device=model.device, non_blocking=True) + + temperature = 0 + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor, + image_sizes=[image1.size] * len(image_tensor_list), + do_sample=True if temperature > 0 else False, + temperature=temperature, + top_p=None, + num_beams=1, + max_new_tokens=64, + use_cache=True) # + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + captions_outputs = outputs + + # similar captions filter + filter_caption_idx = [] + caption_tokens = clip_tokenizer(outputs, padding=True, return_tensors="pt").to("cuda") + caption_text_features = clip_model.get_text_features(**caption_tokens) + # print(caption_text_features.shape) + for temp_idx in range(int(len(outputs)/2)): + cos = torch.cosine_similarity(caption_text_features[2 * temp_idx], caption_text_features[2 * temp_idx + 1], dim=0) + if cos<0.85: #0.9 + filter_caption_idx.append(temp_idx) + + if len(filter_caption_idx) == 0: + continue + + # caption quality filter + final_filter_idx = [] + filter_captions = [] + filter_idx = [] + crop_img_list = [] + for temp_idx in filter_caption_idx: + crop_pil_img1 = image1.crop(crop_bbox_list[temp_idx]) + crop_pil_img2 = image2.crop(crop_bbox_list[temp_idx]) + crop_img_list.append(crop_pil_img1) + crop_img_list.append(crop_pil_img2) + filter_captions.append(outputs[temp_idx * 2]) + filter_captions.append(outputs[temp_idx * 2 + 1]) + + # print(crop_inputs_list) + crop_inputs_list = blip_processor(crop_img_list, filter_captions, return_tensors="pt", padding=True).to("cuda", torch.float16) + cosine_score = blip_model(**crop_inputs_list, use_itm_head=False).itm_score + + for temp_idx in range(len(filter_caption_idx)): + if cosine_score[temp_idx * 2][temp_idx * 2] > 0.35 and cosine_score[temp_idx * 2 + 1][temp_idx * 2 + 1] > 0.35: #0.3 + final_filter_idx.append(filter_caption_idx[temp_idx]) + + if len(final_filter_idx) == 0: + continue + + red_bbox_img_list1 = get_sub_list(red_bbox_img_list1, final_filter_idx) + red_bbox_img_list2 = get_sub_list(red_bbox_img_list2, final_filter_idx) + concat_image_tensor_list = get_sub_list(concat_image_tensor_list, final_filter_idx) + concat_image_tensor = torch.stack(concat_image_tensor_list).to(dtype=torch.float16, device=model.device, non_blocking=True) + + concat_image_input_ids_list = [] + for temp_idx in final_filter_idx: + prompt = "Analyse the the left image and the right image (separated by the black vertical bar). The detail within the red bounding box in the left image is: " + outputs[temp_idx * 2] + ", " + \ + "while the detail within the red bounding box in the right image is: " + outputs[temp_idx * 2 + 1] + ". What is their difference? Answer with a few concise sentences."# + prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt + conv = conv_templates["vicuna_v1"].copy() + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + input_ids = input_ids.to(device=model.device, non_blocking=True) + + concat_image_input_ids_list.append(input_ids) + + + max_len = 0 + for temp_idx in range(len(concat_image_input_ids_list)): + max_len = max(max_len, len(concat_image_input_ids_list[temp_idx])) + + for temp_idx in range(len(concat_image_input_ids_list)): + if len(concat_image_input_ids_list[temp_idx]) < max_len: + concat_image_input_ids_list[temp_idx] = torch.cat((torch.zeros(max_len-len(concat_image_input_ids_list[temp_idx])).to(model.device), concat_image_input_ids_list[temp_idx])).long() + + concat_image_input_ids = torch.stack(concat_image_input_ids_list) + + temperature = 0 + with torch.inference_mode(): + output_ids = model.generate( + concat_image_input_ids.to(model.device), + images=concat_image_tensor, + image_sizes=[concat_image.size] * len(concat_image_tensor_list), + do_sample=True if temperature > 0 else False, + temperature=temperature, + top_p=None, + num_beams=1, + max_new_tokens=128, + use_cache=True) # + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + + + + for enumerate_idx, temp_idx in enumerate(final_filter_idx): + temp_json = {} + + temp_json["conversations"] = [] + temp_json["bbox"] = bbox_list[temp_idx] + + temp_json["captions1"] = captions_outputs[temp_idx * 2] + temp_json["captions2"] = captions_outputs[temp_idx * 2 + 1] + + temp_conversation = {} + temp_conversation["from"] = "human" + temp_conversation["value"] = "Analyse the the left image and the right image (separated by the black vertical bar). What is the difference between the red bounding box area in each image? Answer the question in a few concise sentences." + temp_json["conversations"].append(temp_conversation) + + temp_conversation = {} + temp_conversation["from"] = "gpt" + temp_conversation["value"] = outputs[enumerate_idx] + temp_json["conversations"].append(temp_conversation) + + temp_json["path"] = os.path.join(args.output_img_dir, image_path.split("/")[-1] + "_" + str(enumerate_idx)) + red_bbox_img_list1[enumerate_idx].save(temp_json["path"] + "_0.jpg") + red_bbox_img_list2[enumerate_idx].save(temp_json["path"] + "_1.jpg") + + new_json.append(temp_json) + + + + with open(os.path.join(args.output_file), "w") as new_f: + new_f.write(json.dumps(new_json)) + diff --git a/Img-Diff-codes/object_replacement/generate_final_data_new_edit.sh b/Img-Diff-codes/object_replacement/generate_final_data_new_edit.sh new file mode 100644 index 000000000..0af0a304e --- /dev/null +++ b/Img-Diff-codes/object_replacement/generate_final_data_new_edit.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES="0" python generate_final_data_new_edit.py \ + --inp2p_bbox_json_path ./bbox_new_edit_09_098_only_target_025sam_0.json \ + --llava_path llava-v1.6-vicuna-7b \ + --clip_path clip-vit-base-patch32 \ + --blip_path blip-itm-large-coco \ + --img_dir ./prompt-to-prompt-with-sdxl/output \ + --output_img_dir ./filtered_new_edit_data_9_98_only_target_025sam \ + --qa_turns 5 \ + --output_file ./generate_final_data_only_target_9_98_025sam_0.json \ No newline at end of file diff --git a/Img-Diff-codes/pairs_generator/gen.py b/Img-Diff-codes/pairs_generator/gen.py new file mode 100644 index 000000000..3dc152140 --- /dev/null +++ b/Img-Diff-codes/pairs_generator/gen.py @@ -0,0 +1,99 @@ +import sys +sys.path.append("./FastChat") + +import argparse +import json +import re +import time +from transformers import LlamaForCausalLM +from transformers import LlamaTokenizer +import torch +import tqdm +from fastchat.model import get_conversation_template +import tqdm +import random +import transformers + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--vicuna_path', type=str, default="vicuna-13b-v1.5") + parser.add_argument('--json_path', type=str, default="./data.json") + parser.add_argument('--output_path', type=str, default="./output.json") + args=parser.parse_args() + + return args + +def seed_everything(seed): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +if __name__ == "__main__": + + args = parse_args() + answer_json = [] + + model_path = args.vicuna_path + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = LlamaForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=torch.float16, _attn_implementation="flash_attention_2" + ).half().cuda() + + with open(args.json_path, "r") as f: + data = json.load(f) + + for temp_caption in tqdm.tqdm(data): + # temp_caption = temp_caption["conversations"][1]["value"] + with torch.no_grad(): + with torch.inference_mode(): + # for temp_idx in range(5): + + msg = "Here is a sentence: \"" + temp_caption + "\". Please replace one entity in this sentence with another entity, such as an animal, a vehicle, or a piece of furniture. Please only answer with the replaced sentence." + # print(msg) + conv = get_conversation_template(model_path) + conv.append_message(conv.roles[0], msg) + conv.append_message(conv.roles[1], None) + PROMPT = conv.get_prompt() + ids = tokenizer.encode(PROMPT) + input_ids = torch.LongTensor([ids]).to("cuda") + + + + seed_everything(random.randint(1,10000)) + + out = model.generate( + input_ids=input_ids, + max_new_tokens=128, + do_sample=True, + temperature=0.8 + ) + out_text = tokenizer.decode(out[0]) + # out_text = tokenizer.batch_decode(out) + + answer = out_text.replace(PROMPT, "").replace("\nEND", "").replace("", "").replace("", "").strip() + + + if "replac" in answer.lower() or "modified" in answer.lower() or "become" in answer.lower(): + continue + + if "---" in answer: + answer = answer.split("---")[-1].strip() + + + # print(answer) + # print(temp_caption) + # print(answer) + temp_json = {"input":temp_caption, "output":answer} + + # print(temp_json) + answer_json.append(temp_json) + + temp_caption = answer + + # break + + with open(args.output_path, "w") as new_f: + new_f.write(json.dumps(answer_json)) \ No newline at end of file diff --git a/Img-Diff-codes/pairs_generator/gen.sh b/Img-Diff-codes/pairs_generator/gen.sh new file mode 100644 index 000000000..c8081b6cb --- /dev/null +++ b/Img-Diff-codes/pairs_generator/gen.sh @@ -0,0 +1,5 @@ +CUDA_VISIBLE_DEVICES="0" python gen.py \ + --vicuna_path vicuna-7b-v1.5 \ + --json_path data.json \ + --output_path ./gen_llava_0.json + diff --git a/Img-Diff-codes/pairs_generator/gen_new_data_ddp.py b/Img-Diff-codes/pairs_generator/gen_new_data_ddp.py new file mode 100644 index 000000000..90c8c4103 --- /dev/null +++ b/Img-Diff-codes/pairs_generator/gen_new_data_ddp.py @@ -0,0 +1,81 @@ +import torch +from prompt_to_prompt_pipeline import Prompt2PromptPipeline +import tqdm +import argparse +import json +import os +from torch.utils.data import DataLoader, Dataset, SequentialSampler + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str, default="stable-diffusion-xl-base-1.0") + parser.add_argument('--json_path', type=str, default="./data.json") + parser.add_argument('--output_path', type=str, default="./output/") + parser.add_argument("--local-rank", type=int) + args=parser.parse_args() + + return args + + +class InferenceDataset(Dataset): + + def __init__(self, json_path): + with open(json_path, "r") as f: + self.data = json.load(f) + + def __len__(self) -> int: + return len(self.data) + + @torch.no_grad() + def __getitem__(self, idx: int): + return self.data[idx]["input"], self.data[idx]["output"], self.data[idx]["id"] + + +if __name__ == "__main__": + args = parse_args() + local_rank=args.local_rank + # print(local_rank) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group('nccl', init_method='env://') + device = torch.device(f'cuda:{args.local_rank}') + + model_path = args.model_path + pipe = Prompt2PromptPipeline.from_pretrained(model_path, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device) + # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + seed = 864 + g_cpu = torch.Generator().manual_seed(seed) + + dataset = InferenceDataset(args.json_path) + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, drop_last=False) + # pipe=torch.nn.parallel.DistributedDataParallel(pipe, device_ids=[args.local_rank]) + + + + + + # file_name = args.json_path.replace("../gen_llava_", "").replace(".json", "") + cross_attention_kwargs = {"edit_type": "refine", + "n_self_replace": 0.4, + "n_cross_replace": {"default_": 1.0, "confetti": 0.8}, + } + + with torch.no_grad(): + for temp_idx, (temp_input, temp_output, temp_img_id) in enumerate(tqdm.tqdm(dataloader)): + + if "replaced" in temp_output[0].lower() or "modified" in temp_output[0].lower(): + continue + + if "---" in temp_output[0]: + continue + + # try: + prompts = [temp_input[0].strip("\""), temp_output[0].strip("\"")] + image = pipe(prompts, cross_attention_kwargs=cross_attention_kwargs, generator=g_cpu) + + + for idx, img in enumerate(image['images']): + img.save(os.path.join(args.output_path, str(temp_img_id[0]) + f"_{str(idx)}.jpg")) + # except: + # continue diff --git a/Img-Diff-codes/pairs_generator/gen_sdxl_new_data_ddp.sh b/Img-Diff-codes/pairs_generator/gen_sdxl_new_data_ddp.sh new file mode 100644 index 000000000..dc6e7e255 --- /dev/null +++ b/Img-Diff-codes/pairs_generator/gen_sdxl_new_data_ddp.sh @@ -0,0 +1,4 @@ +python -m torch.distributed.launch --nproc_per_node=4 gen_new_data_ddp.py \ + --model_path stable-diffusion-xl-base-1.0 \ + --json_path ./gen_vg.json \ + --output_path ./output_vg_ddp \ No newline at end of file diff --git a/Img-Diff-codes/pairs_generator/processors.py b/Img-Diff-codes/pairs_generator/processors.py new file mode 100644 index 000000000..6ee454c33 --- /dev/null +++ b/Img-Diff-codes/pairs_generator/processors.py @@ -0,0 +1,596 @@ +from __future__ import annotations + +import abc +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models.attention import Attention + + +class P2PCrossAttnProcessor: + def __init__(self, controller, place_in_unet): + super().__init__() + self.controller = controller + self.place_in_unet = place_in_unet + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + + # one line change + self.controller(attention_probs, is_cross, self.place_in_unet) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +def create_controller( + prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device, attn_res +) -> AttentionControl: + edit_type = cross_attention_kwargs.get("edit_type", None) + local_blend_words = cross_attention_kwargs.get("local_blend_words", None) + equalizer_words = cross_attention_kwargs.get("equalizer_words", None) + equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths", None) + n_cross_replace = cross_attention_kwargs.get("n_cross_replace", 0.4) + n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4) + + # only replace + if edit_type == "replace" and local_blend_words is None: + return AttentionReplace( + prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, attn_res=attn_res + ) + + # replace + localblend + if edit_type == "replace" and local_blend_words is not None: + lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) + return AttentionReplace( + prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res + ) + + # only refine + if edit_type == "refine" and local_blend_words is None: + return AttentionRefine( + prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, attn_res=attn_res + ) + + # refine + localblend + if edit_type == "refine" and local_blend_words is not None: + lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) + return AttentionRefine( + prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res + ) + + # only reweight + if edit_type == "reweight" and local_blend_words is None: + assert ( + equalizer_words is not None and equalizer_strengths is not None + ), "To use reweight edit, please specify equalizer_words and equalizer_strengths." + assert len(equalizer_words) == len( + equalizer_strengths + ), "equalizer_words and equalizer_strengths must be of same length." + equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) + return AttentionReweight( + prompts, + num_inference_steps, + n_cross_replace, + n_self_replace, + tokenizer=tokenizer, + device=device, + equalizer=equalizer, + attn_res=attn_res, + ) + + # reweight and localblend + if edit_type == "reweight" and local_blend_words: + assert ( + equalizer_words is not None and equalizer_strengths is not None + ), "To use reweight edit, please specify equalizer_words and equalizer_strengths." + assert len(equalizer_words) == len( + equalizer_strengths + ), "equalizer_words and equalizer_strengths must be of same length." + equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) + lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) + return AttentionReweight( + prompts, + num_inference_steps, + n_cross_replace, + n_self_replace, + tokenizer=tokenizer, + device=device, + equalizer=equalizer, + attn_res=attn_res, + local_blend=lb, + ) + + raise ValueError(f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight.") + + +class AttentionControl(abc.ABC): + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + return 0 + + @abc.abstractmethod + def forward(self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + h = attn.shape[0] + attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self, attn_res=None): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + self.attn_res = attn_res + + +class EmptyControl(AttentionControl): + def forward(self, attn, is_cross: bool, place_in_unet: str): + return attn + + +class AttentionStore(AttentionControl): + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= 32**2: # avoid memory overhead + self.step_store[key].append(attn) + return attn + + def between_steps(self): + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = { + key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store + } + return average_attention + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self, attn_res=None): + super(AttentionStore, self).__init__(attn_res) + self.step_store = self.get_empty_store() + self.attention_store = {} + + + +class LocalBlend: + def __call__(self, x_t, attention_store): + # note that this code works on the latent level! + k = 1 + # maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] # These are the numbers because we want to take layers that are 256 x 256, I think this can be changed to something smarter...like, get all attentions where thesecond dim is self.attn_res[0] * self.attn_res[1] in up and down cross. + maps = [m for m in attention_store["down_cross"] + attention_store["mid_cross"] + attention_store["up_cross"] if m.shape[1] == self.attn_res[0] * self.attn_res[1]] + maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, self.attn_res[0], self.attn_res[1], self.max_num_words) for item in maps] + maps = torch.cat(maps, dim=1) + maps = (maps * self.alpha_layers).sum(-1).mean(1) # since alpha_layers is all 0s except where we edit, the product zeroes out all but what we change. Then, the sum adds the values of the original and what we edit. Then, we average across dim=1, which is the number of layers. + mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) + mask = F.interpolate(mask, size=(x_t.shape[2:])) + mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] + mask = mask.gt(self.threshold) + + mask = mask[:1] + mask[1:] + mask = mask.to(torch.float16) + + x_t = x_t[:1] + mask * (x_t - x_t[:1]) # x_t[:1] is the original image. mask*(x_t - x_t[:1]) zeroes out the original image and removes the difference between the original and each image we are generating (mostly just one). Then, it applies the mask on the image. That is, it's only keeping the cells we want to generate. + return x_t + + def __init__( + self, prompts: List[str], words: [List[List[str]]], tokenizer, device, threshold=0.3, attn_res=None + ): + self.max_num_words = 77 + self.attn_res = attn_res + + alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words) + for i, (prompt, words_) in enumerate(zip(prompts, words)): + if isinstance(words_, str): + words_ = [words_] + for word in words_: + ind = get_word_inds(prompt, word, tokenizer) + alpha_layers[i, :, :, :, :, ind] = 1 + self.alpha_layers = alpha_layers.to(device) # a one-hot vector where the 1s are the words we modify (source and target) + self.threshold = threshold + + +class AttentionControlEdit(AttentionStore, abc.ABC): + def step_callback(self, x_t): + if self.local_blend is not None: + x_t = self.local_blend(x_t, self.attention_store) + return x_t + + def replace_self_attention(self, attn_base, att_replace): + if att_replace.shape[2] <= self.attn_res[0]**2: + return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) + else: + return att_replace + + @abc.abstractmethod + def replace_cross_attention(self, attn_base, att_replace): + raise NotImplementedError + + def forward(self, attn, is_cross: bool, place_in_unet: str): + super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) + if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): + h = attn.shape[0] // (self.batch_size) + attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) + attn_base, attn_replace = attn[0], attn[1:] + if is_cross: + alpha_words = self.cross_replace_alpha[self.cur_step] + attn_replace_new = ( + self.replace_cross_attention(attn_base, attn_replace) * alpha_words + + (1 - alpha_words) * attn_replace + ) + attn[1:] = attn_replace_new + else: + attn[1:] = self.replace_self_attention(attn_base, attn_replace) + attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) + return attn + + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], + self_replace_steps: Union[float, Tuple[float, float]], + local_blend: Optional[LocalBlend], + tokenizer, + device, + attn_res=None, + ): + super(AttentionControlEdit, self).__init__(attn_res=attn_res) + # add tokenizer and device here + + self.tokenizer = tokenizer + self.device = device + + self.batch_size = len(prompts) + self.cross_replace_alpha = get_time_words_attention_alpha( + prompts, num_steps, cross_replace_steps, self.tokenizer + ).to(self.device) + if isinstance(self_replace_steps, float): + self_replace_steps = 0, self_replace_steps + self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) + self.local_blend = local_blend + + +class AttentionReplace(AttentionControlEdit): + def replace_cross_attention(self, attn_base, att_replace): + return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper) + + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, + tokenizer=None, + device=None, + attn_res=None, + ): + super(AttentionReplace, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res + ) + self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device) + + +class AttentionRefine(AttentionControlEdit): + def replace_cross_attention(self, attn_base, att_replace): + attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) + attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) + return attn_replace + + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, + tokenizer=None, + device=None, + attn_res=None + ): + super(AttentionRefine, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res + ) + self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer) + self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device) + self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) + + +class AttentionReweight(AttentionControlEdit): + def replace_cross_attention(self, attn_base, att_replace): + if self.prev_controller is not None: + attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) + attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] + return attn_replace + + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + equalizer, + local_blend: Optional[LocalBlend] = None, + controller: Optional[AttentionControlEdit] = None, + tokenizer=None, + device=None, + attn_res=None, + ): + super(AttentionReweight, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res + ) + self.equalizer = equalizer.to(self.device) + self.prev_controller = controller + + +### util functions for all Edits +def update_alpha_time_word( + alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None +): + if isinstance(bounds, float): + bounds = 0, bounds + start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) + if word_inds is None: + word_inds = torch.arange(alpha.shape[2]) + alpha[:start, prompt_ind, word_inds] = 0 + alpha[start:end, prompt_ind, word_inds] = 1 + alpha[end:, prompt_ind, word_inds] = 0 + return alpha + + +def get_time_words_attention_alpha( + prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77 +): + if not isinstance(cross_replace_steps, dict): + cross_replace_steps = {"default_": cross_replace_steps} + if "default_" not in cross_replace_steps: + cross_replace_steps["default_"] = (0.0, 1.0) + alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) + for i in range(len(prompts) - 1): + alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i) + for key, item in cross_replace_steps.items(): + if key != "default_": + inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] + for i, ind in enumerate(inds): + if len(ind) > 0: + alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) + alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) + return alpha_time_words + + +### util functions for LocalBlend and ReplacementEdit +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if isinstance(word_place, str): + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif isinstance(word_place, int): + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +### util functions for ReplacementEdit +def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): + words_x = x.split(" ") + words_y = y.split(" ") + if len(words_x) != len(words_y): + raise ValueError( + f"attention replacement edit can only be applied on prompts with the same length" + f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words." + ) + inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] + inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] + inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] + mapper = np.zeros((max_len, max_len)) + i = j = 0 + cur_inds = 0 + while i < max_len and j < max_len: + if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: + inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] + if len(inds_source_) == len(inds_target_): + mapper[inds_source_, inds_target_] = 1 + else: + ratio = 1 / len(inds_target_) + for i_t in inds_target_: + mapper[inds_source_, i_t] = ratio + cur_inds += 1 + i += len(inds_source_) + j += len(inds_target_) + elif cur_inds < len(inds_source): + mapper[i, j] = 1 + i += 1 + j += 1 + else: + mapper[j, j] = 1 + i += 1 + j += 1 + + # return torch.from_numpy(mapper).float() + return torch.from_numpy(mapper).to(torch.float16) + + +def get_replacement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers = [] + for i in range(1, len(prompts)): + mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + return torch.stack(mappers) + + +### util functions for ReweightEdit +def get_equalizer( + text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer +): + if isinstance(word_select, (int, str)): + word_select = (word_select,) + equalizer = torch.ones(len(values), 77) + values = torch.tensor(values, dtype=torch.float32) + for i, word in enumerate(word_select): + inds = get_word_inds(text, word, tokenizer) + equalizer[:, inds] = torch.FloatTensor(values[i]) + return equalizer + + +### util functions for RefinementEdit +class ScoreParams: + def __init__(self, gap, match, mismatch): + self.gap = gap + self.match = match + self.mismatch = mismatch + + def mis_match_char(self, x, y): + if x != y: + return self.mismatch + else: + return self.match + + +def get_matrix(size_x, size_y, gap): + matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) + matrix[0, 1:] = (np.arange(size_y) + 1) * gap + matrix[1:, 0] = (np.arange(size_x) + 1) * gap + return matrix + + +def get_traceback_matrix(size_x, size_y): + matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) + matrix[0, 1:] = 1 + matrix[1:, 0] = 2 + matrix[0, 0] = 4 + return matrix + + +def global_align(x, y, score): + matrix = get_matrix(len(x), len(y), score.gap) + trace_back = get_traceback_matrix(len(x), len(y)) + for i in range(1, len(x) + 1): + for j in range(1, len(y) + 1): + left = matrix[i, j - 1] + score.gap + up = matrix[i - 1, j] + score.gap + diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) + matrix[i, j] = max(left, up, diag) + if matrix[i, j] == left: + trace_back[i, j] = 1 + elif matrix[i, j] == up: + trace_back[i, j] = 2 + else: + trace_back[i, j] = 3 + return matrix, trace_back + + +def get_aligned_sequences(x, y, trace_back): + x_seq = [] + y_seq = [] + i = len(x) + j = len(y) + mapper_y_to_x = [] + while i > 0 or j > 0: + if trace_back[i, j] == 3: + x_seq.append(x[i - 1]) + y_seq.append(y[j - 1]) + i = i - 1 + j = j - 1 + mapper_y_to_x.append((j, i)) + elif trace_back[i][j] == 1: + x_seq.append("-") + y_seq.append(y[j - 1]) + j = j - 1 + mapper_y_to_x.append((j, -1)) + elif trace_back[i][j] == 2: + x_seq.append(x[i - 1]) + y_seq.append("-") + i = i - 1 + elif trace_back[i][j] == 4: + break + mapper_y_to_x.reverse() + return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) + + +def get_mapper(x: str, y: str, tokenizer, max_len=77): + x_seq = tokenizer.encode(x) + y_seq = tokenizer.encode(y) + score = ScoreParams(0, 1, -1) + matrix, trace_back = global_align(x_seq, y_seq, score) + mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] + alphas = torch.ones(max_len) + alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() + mapper = torch.zeros(max_len, dtype=torch.int64) + mapper[: mapper_base.shape[0]] = mapper_base[:, 1] + mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq)) + return mapper, alphas + + +def get_refinement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers, alphas = [], [] + for i in range(1, len(prompts)): + mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + alphas.append(alpha) + return torch.stack(mappers), torch.stack(alphas) diff --git a/Img-Diff-codes/pairs_generator/prompt_to_prompt_pipeline.py b/Img-Diff-codes/pairs_generator/prompt_to_prompt_pipeline.py new file mode 100644 index 000000000..6a30c688c --- /dev/null +++ b/Img-Diff-codes/pairs_generator/prompt_to_prompt_pipeline.py @@ -0,0 +1,473 @@ +from typing import Any, Callable, List +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline +from processors import * + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class Prompt2PromptPipeline(StableDiffusionXLPipeline): + r""" + Args: + Prompt-to-Prompt-Pipeline for text-to-image generation using Stable Diffusion. This model inherits from + [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for + all the pipelines (such as downloading or saving, running on a particular device, etc.) + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler + ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + def _aggregate_and_get_attention_maps_per_token(self, with_softmax): + attention_maps = self.controller.aggregate_attention( + from_where=("up_cross", "down_cross", "mid_cross"), + # from_where=("up", "down"), + # from_where=("down",) + ) + attention_maps_list = self._get_attention_maps_list( + attention_maps=attention_maps, with_softmax=with_softmax + ) + return attention_maps_list + + @staticmethod + def _get_attention_maps_list( + attention_maps: torch.Tensor, with_softmax + ) -> List[torch.Tensor]: + attention_maps *= 100 + + if with_softmax: + attention_maps = torch.nn.functional.softmax(attention_maps, dim=-1) + + attention_maps_list = [ + attention_maps[:, :, i] for i in range(attention_maps.shape[2]) + ] + return attention_maps_list + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + attn_res=None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + The keyword arguments to configure the edit are: + - edit_type (`str`). The edit type to apply. Can be either of `replace`, `refine`, `reweight`. + - n_cross_replace (`int`): Number of diffusion steps in which cross attention should be replaced + - n_self_replace (`int`): Number of diffusion steps in which self attention should be replaced + - local_blend_words(`List[str]`, *optional*, default to `None`): Determines which area should be + changed. If None, then the whole image can be changed. + - equalizer_words(`List[str]`, *optional*, default to `None`): Required for edit type `reweight`. + Determines which words should be enhanced. + - equalizer_strengths (`List[float]`, *optional*, default to `None`) Required for edit type `reweight`. + Determines which how much the words in `equalizer_words` should be enhanced. + + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + if attn_res is None: + attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) + self.attn_res = attn_res + + self.controller = create_controller( + prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=self.device, attn_res=self.attn_res + ) + self.register_attention_control(self.controller) # add attention controller + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents[1] = latents[0] + + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=self.text_encoder_2.config.projection_dim # if none should be changed to enc1 + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # step callback + latents = self.controller.step_callback(latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Post-processing + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + + def register_attention_control(self, controller): + attn_procs = {} + cross_att_count = 0 + for name in self.unet.attn_processors.keys(): + None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + if name.startswith("mid_block"): + self.unet.config.block_out_channels[-1] + place_in_unet = "mid" + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + list(reversed(self.unet.config.block_out_channels))[block_id] + place_in_unet = "up" + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + self.unet.config.block_out_channels[block_id] + place_in_unet = "down" + else: + continue + cross_att_count += 1 + attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet) + + self.unet.set_attn_processor(attn_procs) + controller.num_att_layers = cross_att_count diff --git a/README.md b/README.md index 8189d7a73..422d19f57 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ We release **Img-Diff**, A high-quality synthesis dataset focusing on describin ## Codes and Data Recipes -- The original codes are organized and presented in [Img-Diff](https://github.com/modelscope/data-juicer/tree/ImgDiff/Img-Diff). +- The original codes are organized and presented in [Img-Diff-codes](https://github.com/modelscope/data-juicer/tree/ImgDiff/Img-Diff-codes). - The codes and data recipes in data-juicer format will be released soon.