
Commit

Merge branch 'main' into main
Muinez authored Nov 26, 2024
2 parents a83189a + 7d0d659 commit 7f2bec4
Showing 5 changed files with 26 additions and 9 deletions.
1 change: 0 additions & 1 deletion scripts/inference.py
@@ -146,7 +146,6 @@ def visualize(config, args, model, items, bs, sample_steps, cfg_scale, pag_scale
             latent_size,
             device=device,
             generator=generator,
-            dtype=weight_dtype,
         )
         model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

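
Note on the change above: with dtype=weight_dtype dropped, the initial latent is drawn in torch.randn's default dtype (float32) and only the seeded generator controls its values. A minimal, self-contained sketch of the pattern; the shape and the downstream cast are illustrative, not taken from this repository:

    import torch

    # A seeded generator keeps the latent reproducible across runs.
    generator = torch.Generator(device="cpu").manual_seed(0)
    # Without dtype=, torch.randn samples in the default dtype (float32).
    z = torch.randn(1, 32, 32, 32, device="cpu", generator=generator)
    # If the model runs in half precision, the noise can still be cast afterwards.
    z_bf16 = z.to(torch.bfloat16)
    print(z.dtype, z_bf16.dtype)  # torch.float32 torch.bfloat16
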
8 changes: 7 additions & 1 deletion scripts/inference_dpg.py
@@ -117,7 +117,12 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
     with torch.no_grad():
         n = len(prompts)
         z = torch.randn(
-            n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator
+            n,
+            config.vae.vae_latent_dim,
+            latent_size,
+            latent_size,
+            device=device,
+            generator=generator,
         )
         model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

@@ -432,6 +437,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_dpg_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
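
The one-line addition at the end of this file guards the metrics dump: open(..., "w") raises FileNotFoundError if work_dir/metrics does not exist yet, and exist_ok=True makes the makedirs call safe to repeat. A stand-alone sketch of the same pattern, with a hypothetical work_dir:

    import os
    import time

    work_dir = "output/debug"  # hypothetical path for illustration
    os.makedirs(f"{work_dir}/metrics", exist_ok=True)  # no-op if the directory already exists
    # Without the makedirs call, this open() would fail on a fresh run.
    with open(f"{work_dir}/metrics/tmp_example_{time.time()}.txt", "w") as f:
        f.write("img_save_dir_name\n")
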
8 changes: 7 additions & 1 deletion scripts/inference_geneval.py
@@ -226,7 +226,12 @@ def visualize(sample_steps, cfg_scale, pag_scale):
     with torch.no_grad():
         n = len(prompts)
         z = torch.randn(
-            n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator
+            n,
+            config.vae.vae_latent_dim,
+            latent_size,
+            latent_size,
+            device=device,
+            generator=generator,
         )
         model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

@@ -535,6 +540,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_geneval_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_geneval_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_geneval_{time.time()}.txt")
16 changes: 11 additions & 5 deletions scripts/inference_image_reward.py
@@ -118,7 +118,14 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
     # start sampling
     with torch.no_grad():
         n = len(prompts)
-        z = torch.randn(n, config.vae.vae_latent_dim, latent_size, latent_size, device=device, generator=generator)
+        z = torch.randn(
+            n,
+            config.vae.vae_latent_dim,
+            latent_size,
+            latent_size,
+            device=device,
+            generator=generator,
+        )
         model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

         if args.sampling_algo == "dpm-solver":
@@ -205,7 +212,6 @@ def get_args():
 class SanaInference(SanaConfig):
     config: str = ""
     model_path: Optional[str] = field(default=None, metadata={"help": "Path to the model file (optional)"})
-    version: str = "sigma"
     txt_file: str = "asset/samples.txt"
     json_file: Optional[str] = None
     sample_nums: int = 100_000
@@ -214,7 +220,7 @@ class SanaInference(SanaConfig):
     cfg_scale: float = 4.5
     pag_scale: float = 1.0
     sampling_algo: str = field(
-        default="dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]}
+        default="flow_dpm-solver", metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]}
     )
     seed: int = 0
     dataset: str = "custom"
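
The hunk above flips this script's default sampler from dpm-solver to flow_dpm-solver; the allowed values stay listed in the dataclass field's metadata. A reduced sketch of how such a field behaves; the class below is an illustrative stand-in, not the real SanaInference config:

    from dataclasses import dataclass, field

    @dataclass
    class SamplerArgs:  # illustrative stand-in for the inference config
        sampling_algo: str = field(
            default="flow_dpm-solver",
            metadata={"choices": ["dpm-solver", "sa-solver", "flow_euler", "flow_dpm-solver"]},
        )

    print(SamplerArgs().sampling_algo)  # flow_dpm-solver unless overridden via config/CLI
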
@@ -233,7 +239,6 @@ class SanaInference(SanaConfig):
         default=None, metadata={"help": "A list value, like [0, 1.] for ablation"}
     )
     ablation_key: Optional[str] = field(default=None, metadata={"choices": ["step", "cfg_scale", "pag_scale"]})
-    debug: bool = False
     if_save_dirname: bool = field(
         default=False,
         metadata={"help": "if save img save dir name at wor_dir/metrics/tmp_time.time().txt for metric testing"},
@@ -244,7 +249,6 @@ class SanaInference(SanaConfig):

     args = get_args()
     config = args = pyrallis.parse(config_class=SanaInference, config_path=args.config)
-    # config = read_config(args.config)

     args.image_size = config.model.image_size
     if args.custom_image_size:
@@ -311,6 +315,7 @@ class SanaInference(SanaConfig):
         "linear_head_dim": config.model.linear_head_dim,
         "pred_sigma": pred_sigma,
         "learn_sigma": learn_sigma,
+        "use_fp32_attention": getattr(config.model, "fp32_attention", False),
     }
     model = build_model(config.model.model, **model_kwargs).to(device)
     logger.info(
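
The new use_fp32_attention kwarg is read with getattr(config.model, "fp32_attention", False), so configs written before the flag existed fall back to False instead of raising. A minimal sketch of that fallback; the config objects are mocked with SimpleNamespace for illustration:

    from types import SimpleNamespace

    # Mock configs: one predates the flag, one sets it explicitly.
    old_cfg = SimpleNamespace(linear_head_dim=32)
    new_cfg = SimpleNamespace(linear_head_dim=32, fp32_attention=True)

    # getattr's third argument is the default used when the attribute is missing.
    print(getattr(old_cfg, "fp32_attention", False))  # False
    print(getattr(new_cfg, "fp32_attention", False))  # True
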
@@ -411,6 +416,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")
2 changes: 1 addition & 1 deletion scripts/interface.py
@@ -165,7 +165,7 @@ def generate_img(

     n = len(prompts)
     latent_size_h, latent_size_w = height // config.vae.vae_downsample_rate, width // config.vae.vae_downsample_rate
-    z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device, dtype=weight_dtype)
+    z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device)
     model_kwargs = dict(data_info={"img_hw": (latent_size_h, latent_size_w), "aspect_ratio": 1.0}, mask=emb_masks)
     print(f"Latent Size: {z.shape}")
     # Sample images:
