-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathediting.py
112 lines (82 loc) · 3.56 KB
/
editing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torchvision
import os
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, PretrainedConfig
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
######## Editing Utilities
def get_direction(df, label, pinverse, return_dim, device):
    """Compute the direction associated with ``label`` via ``pinverse @ labels``.

    Args:
        df: pandas DataFrame with one row per sample; column ``label`` holds
            the numeric label of each sample, in index order.
        label: Name of the column of ``df`` to read.
        pinverse: Matrix applied to the label vector. Presumably the
            pseudo-inverse of the sample matrix (the step is a least-squares
            solve) and presumably bfloat16, since the labels are cast to
            bfloat16 before the product -- confirm against the caller.
        return_dim: Width of the returned row vector. Must be at least the
            number of rows of ``pinverse``.
        device: Device the label vector and the padding are placed on.

    Returns:
        Tensor of shape ``(1, return_dim)``: the direction, followed by zeros
        when ``return_dim`` exceeds the width of the direction.
    """
    # Read the whole column in one go rather than one chained
    # ``df.loc[row][label]`` lookup per row.
    labels = torch.tensor(df[label].tolist(), dtype=torch.float32)
    labels = labels.to(device).bfloat16()

    # Solve least squares.
    direction = (pinverse @ labels).unsqueeze(0)

    # Pad relative to the actual width of the direction instead of assuming it
    # is always 1000 wide.
    pad_width = return_dim - direction.shape[1]
    if pad_width == 0:
        return direction

    # NOTE: the padding deliberately keeps torch's default dtype, exactly as
    # before, so the dtype of the padded result is whatever torch.cat yields
    # for (direction dtype, default dtype).
    padding = torch.zeros((1, pad_width), device=device)
    return torch.cat((direction, padding), dim=1)
def debias(direction, label, df, pinverse, device):
    """Remove from ``direction`` its component along the ``label`` direction.

    The direction ``d`` for ``label`` is computed as ``pinverse @ labels``,
    zero-padded to the width of ``direction``, and then projected out:
    ``direction - (direction . d / ||d||^2) * d``.

    Args:
        direction: Row vector of shape ``(1, D)`` to be debiased.
        label: Name of the column of ``df`` whose direction is removed.
        df: pandas DataFrame with one row per sample; column ``label`` holds
            the numeric label of each sample, in index order.
        pinverse: Matrix applied to the label vector. Presumably the
            pseudo-inverse of the sample matrix and presumably bfloat16 --
            confirm against the caller.
        device: Device the label vector is placed on.

    Returns:
        Tensor with the same shape as ``direction``.
    """
    # Read the whole column in one go rather than one chained
    # ``df.loc[row][label]`` lookup per row.
    labels = torch.tensor(df[label].tolist(), dtype=torch.float32)
    labels = labels.to(device).bfloat16()

    # Solve least squares. Cast to the dtype of ``direction`` so the matmul
    # below sees matching dtypes whether or not padding is needed.
    d = (pinverse @ labels).unsqueeze(0).to(direction.dtype)

    # Align dimensionalities of the two vectors, based on their actual widths
    # rather than an assumed width of 1000.
    pad_width = direction.shape[1] - d.shape[1]
    if pad_width != 0:
        padding = torch.zeros((1, pad_width), dtype=d.dtype, device=d.device)
        d = torch.cat((d, padding), dim=1)

    # Remove this component from the direction.
    direction = direction - ((direction @ d.T) / (torch.norm(d) ** 2)) * d
    return direction
# The called form ``torch.no_grad()`` works as a decorator on every PyTorch
# version; the bare ``@torch.no_grad`` form relies on newer releases.
@torch.no_grad()
def edit_inference(network, edited_weights, unet, vae, text_encoder, tokenizer, prompt, negative_prompt, guidance_scale, noise_scheduler, ddim_steps, start_noise, seed, generator, device):
    """Sample one image, switching ``network`` to ``edited_weights`` part-way.

    Runs a classifier-free-guidance denoising loop. For every timestep ``t``
    with ``t <= start_noise`` the network's ``proj`` is replaced by
    ``edited_weights`` before the UNet call; earlier (noisier) steps use
    whatever weights the network already holds. The original ``proj`` is
    restored before returning, including when sampling raises.

    Args:
        network: Object exposing ``proj``, ``reset()`` and the context-manager
            protocol. Presumably the LoRA wrapper around ``unet`` -- confirm.
        edited_weights: Tensor installed as ``network.proj`` for the edited steps.
        unet: Denoising model; called on the latents and must expose ``in_channels``.
        vae: Model whose ``decode(...).sample`` turns latents into an image.
        text_encoder: Maps token ids to embeddings (first output element is used).
        tokenizer: Tokenizer for ``prompt`` and ``negative_prompt``.
        prompt: Conditioning text.
        negative_prompt: Text used for the unconditional branch.
        guidance_scale: Weight of the (text - unconditional) noise difference.
        noise_scheduler: Scheduler providing ``set_timesteps``, ``timesteps``,
            ``init_noise_sigma``, ``scale_model_input`` and ``step``.
        ddim_steps: Number of timesteps passed to ``set_timesteps``.
        start_noise: Timestep threshold at or below which the edit is applied.
        seed: Seed applied to ``generator``.
        generator: ``torch.Generator`` used to draw the initial latents.
        device: Device for the latents and token ids.

    Returns:
        The decoded image tensor, rescaled with ``x / 2 + 0.5`` and clamped to ``[0, 1]``.
    """
    # Snapshot first so the network can be restored however this call exits.
    original_weights = network.proj.clone()
    try:
        generator = generator.manual_seed(seed)
        latents = torch.randn(
            (1, unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=device,
        ).bfloat16()

        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

        # Truncate the negative prompt too, so both embeddings share one
        # sequence length and can be concatenated along the batch dimension.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [negative_prompt], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        noise_scheduler.set_timesteps(ddim_steps)
        latents = latents * noise_scheduler.init_noise_sigma

        for t in tqdm.tqdm(noise_scheduler.timesteps):
            # Duplicate the latents: one copy per guidance branch.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

            # At or below the threshold, run the UNet with the edited weights.
            if t <= start_noise:
                network.proj = torch.nn.Parameter(edited_weights)
                network.reset()

            with network:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

            # Guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # NOTE(review): 0.18215 is presumably the VAE latent scaling factor -- confirm.
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
    finally:
        # Reset weights back to original, even if sampling failed part-way.
        network.proj = torch.nn.Parameter(original_weights)
        network.reset()
    return image