Skip to content

Commit

Permalink
add keyword identifier for loras
Browse files Browse the repository at this point in the history
  • Loading branch information
runew0lf committed Dec 5, 2023
1 parent fe342e2 commit f91ecf7
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 36 deletions.
4 changes: 2 additions & 2 deletions modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def process(gen_data):
outputs.append(["preview", (-1, f"Loading LoRA models ...", None)])
lora_keywords = pipeline.load_loras(loras)
if lora_keywords is None:
lora_keywords = " "
lora_keywords = ""

if gen_data["performance_selection"] == NEWPERF:
steps = gen_data["custom_steps"]
Expand Down Expand Up @@ -212,7 +212,7 @@ def callback(step, x0, x, total_steps, y):
p_txt, n_txt = process_prompt(
gen_data["style_selection"], pos_stripped, neg_stripped, gen_data
)
p_txt += lora_keywords
p_txt = lora_keywords + p_txt
start_step = 0
denoise = None
with TimeIt("Pipeline process"):
Expand Down
9 changes: 7 additions & 2 deletions modules/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_abspath(path):
extensions = [".pth", ".ckpt", ".bin", ".safetensors"]


def get_model_filenames(folder_path):
def get_model_filenames(folder_path, isLora=False):
if not os.path.isdir(folder_path):
raise ValueError("Folder path is not a valid directory.")

Expand All @@ -83,6 +83,11 @@ def get_model_filenames(folder_path):
_, ext = os.path.splitext(filename)
if ext.lower() in [".pth", ".ckpt", ".bin", ".safetensors"]:
path = os.path.join(relative_path, filename)
if isLora:
txtcheck = path.replace(".safetensors", ".txt")
if os.path.isfile(f"{folder_path}{txtcheck}"):
path = path + " 🗒️"

filenames.append(path)

return sorted(
Expand All @@ -94,7 +99,7 @@ def get_model_filenames(folder_path):
def update_all_model_names():
    """Rescan the model folders and refresh the cached filename lists.

    Side effects: rebinds the module-level ``model_filenames``,
    ``lora_filenames`` and ``upscaler_filenames`` globals.
    """
    # The extracted source carried both the pre- and post-change
    # assignment for lora_filenames (diff artifact); only the updated
    # call with the LoRA flag is kept. isLora=True marks LoRAs that
    # ship a keyword .txt file next to the weights.
    global model_filenames, lora_filenames, upscaler_filenames
    model_filenames = get_model_filenames(modelfile_path)
    lora_filenames = get_model_filenames(lorafile_path, True)
    upscaler_filenames = get_model_filenames(upscaler_path)

Expand Down
63 changes: 32 additions & 31 deletions modules/sdxl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,14 @@ def load_all_keywords(self, loras):
for name, weight in loras:
if name == "None" or weight == 0:
continue

name = name.strip(" 🗒️")
filename = os.path.join(modules.path.lorafile_path, name)
lora_prompt_addition = (
f"{lora_prompt_addition}, {self.load_keywords(filename)}"
f"{lora_prompt_addition} {self.load_keywords(filename)}, "
)
return lora_prompt_addition

# name = name.strip(" 🗒️")
def load_loras(self, loras):
lora_prompt_addition = self.load_all_keywords(loras)
if self.xl_base_patched_hash == str(loras):
Expand All @@ -182,7 +183,7 @@ def load_loras(self, loras):
for name, weight in loras:
if name == "None" or weight == 0:
continue

name = name.strip(" 🗒️")
filename = os.path.join(modules.path.lorafile_path, name)
print(f"Loading LoRAs: {name}")
with suppress_stdout():
Expand Down Expand Up @@ -232,7 +233,6 @@ def clean_prompt_cond_caches(self):
self.conditions["-"]["cache"] = None
self.conditions["switch"]["text"] = None
self.conditions["switch"]["cache"] = None


def textencode(self, id, text):
update = False
Expand All @@ -244,39 +244,42 @@ def textencode(self, id, text):
self.conditions[id]["text"] = text
update = True
return update

def set_timestep_range(self, conditioning, start, end):
    """Stamp a step-percentage window onto conditioning entries.

    Mutates every dict in *conditioning* that carries a
    ``"pooled_output"`` key, setting ``"start_percent"`` to *start* and
    ``"end_percent"`` to *end*, then returns the same list.
    """
    # NOTE(review): the extracted span contained both the pre- and
    # post-reformat copies of this loop body (a diff artifact) plus an
    # unused `c = []`; collapsed to the single intended pass — behavior
    # is identical since both copies assigned the same values.
    for t in conditioning:
        if "pooled_output" in t:
            t["start_percent"] = start
            t["end_percent"] = end
    return conditioning



def prompt_switch_per_step(self, prompt, steps):
    """Expand bracketed alternatives into one prompt per sampling step.

    A group like ``[cat|dog]`` cycles through its options: step ``i``
    uses option ``i % len(options)``. Returns a list of *steps* prompt
    strings; a prompt without bracket groups is repeated unchanged.
    """
    # NOTE(review): the extracted span interleaved removed and added
    # diff lines — the regex appeared twice and `.replace` ran twice per
    # group, which would substitute two options per step. Reconstructed
    # the intended single-substitution version.
    # Find all occurrences of [option1|option2|...] in the input string.
    options_pattern = r"\[([^|\]]+(?:\|[^|\]]+)*)\]"
    options_list = []
    exact_matches = []
    for match in re.finditer(options_pattern, prompt):
        body = match.group(1)
        # A bracket with no "|" still counts as a single-option group.
        options_list.append(body.split("|") if "|" in body else [body])
        exact_matches.append(match.group(0))

    prompt_per_step = []
    for i in range(steps):
        prompt_to_append = prompt
        for options, exact_match in zip(options_list, exact_matches):
            # Use modulo to cycle through options; replace only one
            # occurrence so each matched group is substituted exactly once.
            replacement = options[i % len(options)]
            prompt_to_append = prompt_to_append.replace(exact_match, replacement, 1)
        prompt_per_step.append(prompt_to_append)

    return prompt_per_step

@torch.inference_mode()
Expand Down Expand Up @@ -307,35 +310,33 @@ def process(
if self.conditions is None:
self.clean_prompt_cond_caches()



if self.textencode("+", positive_prompt):
updated_conditions = True
if self.textencode("-", negative_prompt):
updated_conditions = True

prompt_switch_mode = False
if("|" in positive_prompt):
if "|" in positive_prompt:
prompt_switch_mode = True
prompt_per_step = self.prompt_switch_per_step(positive_prompt, steps)
perc_per_step = round(100/steps,2)

perc_per_step = round(100 / steps, 2)
positive_complete = []
for i in range(len(prompt_per_step)):
if self.textencode("switch", prompt_per_step[i]):
updated_conditions = True
positive_switch = convert_cond(self.conditions["switch"]["cache"])
start_perc = round((perc_per_step*i)/100,2)
end_perc = round((perc_per_step*(i+1))/100,2)
start_perc = round((perc_per_step * i) / 100, 2)
end_perc = round((perc_per_step * (i + 1)) / 100, 2)
if end_perc >= 0.99:
end_perc = 1
positive_switch = self.set_timestep_range(positive_switch,start_perc, end_perc)

positive_switch = self.set_timestep_range(
positive_switch, start_perc, end_perc
)

positive_complete += positive_switch


positive_switch = convert_cond(self.conditions["switch"]["cache"])

positive_switch = convert_cond(self.conditions["switch"]["cache"])

if controlnet is not None and input_image is not None:
worker.outputs.append(["preview", (-1, f"Powering up ...", None)])
Expand All @@ -347,7 +348,7 @@ def process(
)[0]
self.refresh_controlnet(name=controlnet["type"])
if self.xl_controlnet:
if(prompt_switch_mode):
if prompt_switch_mode:
match controlnet["type"].lower():
case "canny":
input_image = Canny().detect_edge(
Expand Down
3 changes: 3 additions & 0 deletions update_log.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
### 1.25.1
* Now displays 🗒️ at the end of the LoRA name if it has a keyword .txt file.

### 1.25.0
* Prompt Bracketing Now works ie `[cat|dog]`
* Updated comfy version
Expand Down
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.25.0"
version = "1.25.1"

0 comments on commit f91ecf7

Please sign in to comment.