Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
lawrence-cj authored Nov 26, 2024
2 parents b1be116 + 7d0d659 commit 34bef75
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 22 deletions.
26 changes: 19 additions & 7 deletions app/app_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@
INFER_SPEED = 0


def norm_ip(img, low, high):
    """Rescale ``img`` in place from the range [low, high] to [0, 1].

    Values outside [low, high] are clamped to the bounds first, then the
    result is shifted by ``low`` and divided by the width of the range.

    Args:
        img: object exposing in-place ``clamp_``/``sub_``/``div_`` (a torch
            tensor at the visible call site). It is mutated, not copied.
        low: lower bound of the input range.
        high: upper bound of the input range.

    Returns:
        The very same ``img`` object, after in-place normalisation.
    """
    img.clamp_(min=low, max=high)
    img.sub_(low)
    # Floor the divisor at 1e-5 so a degenerate range (high <= low) cannot
    # trigger a division by zero.
    span = max(high - low, 1e-5)
    img.div_(span)
    return img


def open_db():
db = sqlite3.connect(COUNTER_DB)
db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
Expand Down Expand Up @@ -285,13 +291,19 @@ def generate(
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
print(img)
else:
if num_imgs > 1:
nrow = 2
else:
nrow = 1
img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
img = [Image.fromarray(img.astype(np.uint8))]
img = [
Image.fromarray(
norm_ip(img, -1, 1)
.mul(255)
.add_(0.5)
.clamp_(0, 255)
.permute(1, 2, 0)
.to("cpu", torch.uint8)
.numpy()
.astype(np.uint8)
)
for img in images
]

torch.cuda.empty_cache()

Expand Down
3 changes: 1 addition & 2 deletions app/sana_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,9 @@ def forward(
self.latent_size_w,
generator=generator,
device=self.device,
dtype=self.weight_dtype,
)
else:
z = latents.to(self.weight_dtype).to(self.device)
z = latents.to(self.device)
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
if self.vis_sampler == "flow_euler":
flow_solver = FlowEuler(
Expand Down
104 changes: 104 additions & 0 deletions configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
data:
data_dir: [data/data_public/dir1]
image_size: 1024
caption_proportion:
prompt: 1
external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
external_clipscore_suffixes:
- _InternVL2-26B_clip_score
- _VILA1-5-13B_clip_score
- _prompt_clip_score
clip_thr_temperature: 0.1
clip_thr: 25.0
load_text_feat: false
load_vae_feat: false
transform: default_train
type: SanaWebDatasetMS
sort_dataset: false
# model config
model:
model: SanaMS_1600M_P1_D20
image_size: 1024
mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
fp32_attention: true
load_from:
resume_from:
aspect_ratio_type: ASPECT_RATIO_1024
multi_scale: true
#pe_interpolation: 1.
attn_type: linear
ffn_type: glumbconv
mlp_acts:
- silu
- silu
-
mlp_ratio: 2.5
use_pe: false
qk_norm: false
class_dropout_prob: 0.1
# PAG
pag_applied_layers:
- 8
# VAE setting
vae:
vae_type: dc-ae
vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
scale_factor: 0.41407
vae_latent_dim: 32
vae_downsample_rate: 32
sample_posterior: true
# text encoder
text_encoder:
text_encoder_name: gemma-2-2b-it
y_norm: true
y_norm_scale_factor: 0.01
model_max_length: 300
# CHI
chi_prompt:
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
- 'Here are examples of how to transform or refine prompts:'
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
- 'User Prompt: '
# Sana schedule Flow
scheduler:
predict_v: true
noise_schedule: linear_flow
pred_sigma: false
flow_shift: 3.0
# logit-normal timestep
weighting_scheme: logit_normal
logit_mean: 0.0
logit_std: 1.0
vis_sampler: flow_dpm-solver
# training setting
train:
num_workers: 10
seed: 1
train_batch_size: 64
num_epochs: 100
gradient_accumulation_steps: 1
grad_checkpointing: true
gradient_clip: 0.1
optimizer:
lr: 1.0e-4
type: AdamW
weight_decay: 0.01
eps: 1.0e-8
betas: [0.9, 0.999]
lr_schedule: constant
lr_schedule_args:
num_warmup_steps: 2000
local_save_vis: true # whether to save logged images locally
visualize: true
eval_sampling_steps: 500
log_interval: 20
save_model_epochs: 5
save_model_steps: 500
work_dir: output/debug
online_metric: false
eval_metric_step: 2000
online_metric_dir: metric_helper
8 changes: 4 additions & 4 deletions diffusion/model/nets/sana_multi_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
y: (N, 1, 120, C) tensor of class labels
"""
bs = x.shape[0]
dtype = x.dtype
timestep = timestep.to(dtype)
y = y.to(dtype)
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
if self.use_pe:
x = self.x_embedder(x)
Expand All @@ -296,7 +296,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
)
.unsqueeze(0)
.to(x.device)
.to(dtype)
.to(self.dtype)
)
x += self.pos_embed_ms # (N, T, D), where T = H * W / patch_size ** 2
else:
Expand Down
1 change: 0 additions & 1 deletion scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def visualize(config, args, model, items, bs, sample_steps, cfg_scale, pag_scale
latent_size,
device=device,
generator=generator,
dtype=weight_dtype,
)
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

Expand Down
8 changes: 7 additions & 1 deletion scripts/inference_dpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
with torch.no_grad():
n = len(prompts)
z = torch.randn(
n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator
n,
config.vae.vae_latent_dim,
latent_size,
latent_size,
device=device,
generator=generator,
)
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

Expand Down Expand Up @@ -432,6 +437,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
os.makedirs(save_root, exist_ok=True)
if args.if_save_dirname and args.gpu_id == 0:
os.makedirs(f"{work_dir}/metrics", exist_ok=True)
# save at work_dir/metrics/tmp_dpg_xxx.txt for metrics testing
with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
Expand Down
8 changes: 7 additions & 1 deletion scripts/inference_geneval.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,12 @@ def visualize(sample_steps, cfg_scale, pag_scale):
with torch.no_grad():
n = len(prompts)
z = torch.randn(
n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator
n,
config.vae.vae_latent_dim,
latent_size,
latent_size,
device=device,
generator=generator,
)
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

Expand Down Expand Up @@ -535,6 +540,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type)
os.makedirs(save_root, exist_ok=True)
if args.if_save_dirname and args.gpu_id == 0:
os.makedirs(f"{work_dir}/metrics", exist_ok=True)
# save at work_dir/metrics/tmp_geneval_xxx.txt for metrics testing
with open(f"{work_dir}/metrics/tmp_geneval_{time.time()}.txt", "w") as f:
print(f"save tmp file at {work_dir}/metrics/tmp_geneval_{time.time()}.txt")
Expand Down
16 changes: 11 additions & 5 deletions scripts/inference_image_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,14 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
# start sampling
with torch.no_grad():
n = len(prompts)
z = torch.randn(n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator)
z = torch.randn(
n,
config.vae.vae_latent_dim,
latent_size,
latent_size,
device=device,
generator=generator,
)
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

if args.sampling_algo == "dpm-solver":
Expand Down Expand Up @@ -205,7 +212,6 @@ def get_args():
class SanaInference(SanaConfig):
config: str = ""
model_path: Optional[str] = field(default=None, metadata={"help": "Path to the model file (optional)"})
version: str = "sigma"
txt_file: str = "asset/samples.txt"
json_file: Optional[str] = None
sample_nums: int = 100_000
Expand All @@ -214,7 +220,7 @@ class SanaInference(SanaConfig):
cfg_scale: float = 4.5
pag_scale: float = 1.0
sampling_algo: str = field(
default="dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]}
default="flow_dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]}
)
seed: int = 0
dataset: str = "custom"
Expand All @@ -233,7 +239,6 @@ class SanaInference(SanaConfig):
default=None, metadata={"help": "A list value, like [0, 1.] for ablation"}
)
ablation_key: Optional[str] = field(default=None, metadata={"choices": ["step", "cfg_scale", "pag_scale"]})
debug: bool = False
if_save_dirname: bool = field(
default=False,
metadata={"help": "if save img save dir name at wor_dir/metrics/tmp_time.time().txt for metric testing"},
Expand All @@ -244,7 +249,6 @@ class SanaInference(SanaConfig):

args = get_args()
config = args = pyrallis.parse(config_class=SanaInference, config_path=args.config)
# config = read_config(args.config)

args.image_size = config.model.image_size
if args.custom_image_size:
Expand Down Expand Up @@ -311,6 +315,7 @@ class SanaInference(SanaConfig):
"linear_head_dim": config.model.linear_head_dim,
"pred_sigma": pred_sigma,
"learn_sigma": learn_sigma,
"use_fp32_attention": getattr(config.model, "fp32_attention", False),
}
model = build_model(config.model.model, **model_kwargs).to(device)
logger.info(
Expand Down Expand Up @@ -411,6 +416,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
os.makedirs(save_root, exist_ok=True)
if args.if_save_dirname and args.gpu_id == 0:
os.makedirs(f"{work_dir}/metrics", exist_ok=True)
# save at work_dir/metrics/tmp_xxx.txt for metrics testing
with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
Expand Down
2 changes: 1 addition & 1 deletion scripts/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def generate_img(

n = len(prompts)
latent_size_h, latent_size_w = height // config.vae.vae_downsample_rate, width // config.vae.vae_downsample_rate
z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device, dtype=weight_dtype)
z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device)
model_kwargs = dict(data_info={"img_hw": (latent_size_h, latent_size_w), "aspect_ratio": 1.0}, mask=emb_masks)
print(f"Latent Size: {z.shape}")
# Sample images:
Expand Down

0 comments on commit 34bef75

Please sign in to comment.