Skip to content

Commit

Permalink
feat: use "AllModelInfo" to get model info
Browse files Browse the repository at this point in the history
  • Loading branch information
linjiX committed Oct 12, 2024
1 parent 5c15a24 commit a683d3e
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
13 changes: 2 additions & 11 deletions iz_helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,15 @@ def closest_upper_divisible_by_eight(num):
return math.ceil(num / 8) * 8


def load_model_from_setting(model_field_name, progress, progress_desc, specified_model=None):
def load_model_from_setting(model_field_name, progress, progress_desc, all_model_info, specified_model=None):
# fix typo in Automatic1111 vs Vlad111
if hasattr(modules.sd_models, "checkpoint_alisases"):
checkPList = modules.sd_models.checkpoint_alisases
elif hasattr(modules.sd_models, "checkpoint_aliases"):
checkPList = modules.sd_models.checkpoint_aliases
else:
raise Exception(
"This is not a compatible StableDiffusion Platform, can not access checkpoints"
)

if specified_model:
model_name = specified_model
else:
model_name = shared.opts.data.get(model_field_name)

if model_name is not None and model_name != "":
checkinfo = checkPList[model_name]
checkinfo = all_model_info.checkpoint_models[model_name]

if not checkinfo:
raise NameError(model_field_name + " Does not exist in your models.")
Expand Down
13 changes: 13 additions & 0 deletions iz_helpers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import modules.shared as shared
from modules.paths_internal import script_path
from modules.paths import Paths
from modules.model_info import AllModelInfo
from .helpers import (
fix_env_Path_ffprobe,
closest_upper_divisible_by_eight,
Expand Down Expand Up @@ -94,6 +95,7 @@ def outpaint_steps(
mask_width,
mask_height,
custom_exit_image,
all_model_info,
frame_correction=True, # TODO: add frame_Correction in UI
):
main_frames = [init_img.convert("RGB")]
Expand Down Expand Up @@ -145,6 +147,7 @@ def outpaint_steps(
inpainting_fill_mode,
inpainting_full_res,
inpainting_padding,
all_model_info,
)

if len(processed.images) > 0:
Expand Down Expand Up @@ -177,6 +180,7 @@ def create_zoom(
request: gr.Request,
id_task,
model_title: str,
raw_model_info: str,
common_prompt_pre,
prompts_array,
common_prompt_suf,
Expand Down Expand Up @@ -240,6 +244,7 @@ def create_zoom(
upscaler_name,
upscale_by,
main_sd_model,
raw_model_info,
inpainting_denoising_strength,
inpainting_full_res,
inpainting_padding,
Expand Down Expand Up @@ -332,6 +337,7 @@ def create_zoom_single(
upscaler_name,
upscale_by,
main_sd_model,
raw_model_info,
inpainting_denoising_strength,
inpainting_full_res,
inpainting_padding,
Expand Down Expand Up @@ -367,6 +373,9 @@ def create_zoom_single(
current_image = current_image.convert("RGB")
current_seed = seed

all_model_info = AllModelInfo(raw_model_info)
all_model_info.check_file_existence()

if custom_init_image:
current_image = custom_init_image.resize(
(width, height), resample=Image.LANCZOS
Expand All @@ -378,6 +387,7 @@ def create_zoom_single(
"infzoom_txt2img_model",
progress,
"Loading Model for txt2img: ",
all_model_info,
specified_model=main_sd_model
)

Expand All @@ -392,6 +402,7 @@ def create_zoom_single(
current_seed,
width,
height,
all_model_info,
)
if len(processed.images) > 0:
current_image = processed.images[0]
Expand All @@ -412,6 +423,7 @@ def create_zoom_single(
"infzoom_inpainting_model",
progress,
"Loading Model for inpainting/img2img: ",
all_model_info,
specified_model=main_sd_model
)
main_frames, processed = outpaint_steps(
Expand All @@ -437,6 +449,7 @@ def create_zoom_single(
mask_width,
mask_height,
custom_exit_image,
all_model_info,
)
all_frames.append(
do_upscaleImg(main_frames[0], upscale_do, upscaler_name, upscale_by)
Expand Down
5 changes: 4 additions & 1 deletion iz_helpers/sd_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def renderTxt2Img(
request: gr.Request, prompt, negative_prompt, sampler, steps, cfg_scale, seed, width, height
request: gr.Request, prompt, negative_prompt, sampler, steps, cfg_scale, seed, width, height, all_model_info
):
processed = None
p = StableDiffusionProcessingTxt2Img(
Expand All @@ -31,6 +31,7 @@ def renderTxt2Img(
height=height,
)
p.set_request(request)
p.set_all_model_info(all_model_info)
with monitor_call_context(
request,
get_function_name_from_processing(p),
Expand Down Expand Up @@ -61,6 +62,7 @@ def renderImg2Img(
inpainting_fill_mode,
inpainting_full_res,
inpainting_padding,
all_model_info,
):
processed = None

Expand All @@ -87,6 +89,7 @@ def renderImg2Img(
)
# p.latent_mask = Image.new("RGB", (p.width, p.height), "white")
p.set_request(request)
p.set_all_model_info(all_model_info)

with monitor_call_context(
request,
Expand Down
4 changes: 4 additions & 0 deletions iz_helpers/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ def on_ui_tabs():
value=jpr["negPrompt"], label="Negative Prompt"
)


main_sd_model = gr.Textbox(
label="stable diffusion checkpoints", value="", visible=False, elem_id="infzoom_sd_model")

raw_model_info = gr.Label(visible=False)

# these button will be moved using JS under the dataframe view as small ones
exportPrompts_button = gr.Button(
value="Export prompts",
Expand Down Expand Up @@ -310,6 +313,7 @@ def on_ui_tabs():
inputs=[
id_task,
main_sd_model,
raw_model_info,
main_common_prompt_pre,
main_prompts,
main_common_prompt_suf,
Expand Down
16 changes: 16 additions & 0 deletions javascript/infinite-zoom.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ function exportPrompts(cppre,p, cpsuf,np, filename = "infinite-zoom-prompts.json
}
}

function _make_value_source(value) {
return {value: value, source: "infinite_zoom"};
}

async function iz_get_all_model_info(model_title, res) {
const checkpoint_titles = [_make_value_source(model_title)];
const prompts = [
_make_value_source(res[3]),
_make_value_source(res[5]),
_make_value_source(res[6]),
...res[4].data.map((item) => _make_value_source(item[1]))
];
return await getAllModelInfoByCheckpointsAndPrompts(checkpoint_titles, prompts);
}

async function iz_submit() {
addGenerateGtagEvent("#iz_submit_button > span", "#iz_generate_button");
await tierCheckButtonInternal("InfiniteZoom");
Expand All @@ -32,6 +47,7 @@ async function iz_submit() {
var res = Array.from(arguments);
res[0] = id;
res[1] = `model_title(${mainModel.value})`;
res[2] = JSON.stringify(await iz_get_all_model_info(mainModel.value, res));

return res;
}
Expand Down

0 comments on commit a683d3e

Please sign in to comment.