Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wildprompt #2102

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ user_path_config-deprecated.txt
/package-lock.json
/.coverage*
/auth.json
civitapi.txt
35 changes: 32 additions & 3 deletions modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def worker():
import extras.face_crop
import fooocus_version

from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion
from modules.sdxl_styles import apply_style, apply_wildcards, apply_wildprompts, get_all_wildprompts, fooocus_expansion
from modules.private_logger import log
from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, \
Expand Down Expand Up @@ -121,6 +121,8 @@ def handler(async_task):

prompt = args.pop()
negative_prompt = args.pop()
wildprompt_selections = args.pop()
wildprompt_generate_all = args.pop()
style_selections = args.pop()
performance_selection = args.pop()
aspect_ratios_selection = args.pop()
Expand Down Expand Up @@ -153,6 +155,7 @@ def handler(async_task):
outpaint_selections = [o.lower() for o in outpaint_selections]
base_model_additional_loras = []
raw_style_selections = copy.deepcopy(style_selections)
raw_wildprompt_selections = copy.deepcopy(wildprompt_selections)
uov_method = uov_method.lower()

if fooocus_expansion in style_selections:
Expand All @@ -162,6 +165,7 @@ def handler(async_task):
use_expansion = False

use_style = len(style_selections) > 0
use_wildprompt = len(wildprompt_selections) > 0

if base_model_name == refiner_model_name:
print(f'Refiner disabled because base model and refiner are same.')
Expand Down Expand Up @@ -376,11 +380,35 @@ def handler(async_task):

progressbar(async_task, 3, 'Processing prompts ...')
tasks = []
for i in range(image_number):
wildprompts = []
wildprompt_count = len(wildprompt_selections)

# Get wildprompts if wildprompt_generate_all is enabled and there is only one wildprompt
if wildprompt_generate_all and wildprompt_count == 1:
wildprompts = get_all_wildprompts(wildprompt_selections)

if len(wildprompts) > 0:
totalprompts = len(wildprompts) * image_number
else:
totalprompts = image_number

for i in range(totalprompts):
task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
task_rng = random.Random(task_seed) # may bind to inpaint noise in the future

if len(wildprompts) > 0:
wildprompt_prompt = wildprompts[i // image_number]
else:
wildprompt_prompt = apply_wildprompts(wildprompt_selections, task_rng) if use_wildprompt else ''

wildprompt_prompt = apply_wildcards(wildprompt_prompt, task_rng)
task_prompt = apply_wildcards(prompt, task_rng)

if wildprompt_prompt and task_prompt:
task_prompt = f"{task_prompt}, {wildprompt_prompt}"
elif wildprompt_prompt:
task_prompt = wildprompt_prompt

task_negative_prompt = apply_wildcards(negative_prompt, task_rng)
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_positive_prompts]
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng) for pmt in extra_negative_prompts]
Expand Down Expand Up @@ -779,6 +807,7 @@ def callback(step, x0, x, total_steps, y):
('Negative Prompt', task['log_negative_prompt']),
('Fooocus V2 Expansion', task['expansion']),
('Styles', str(raw_style_selections)),
('Wildprompts', str(raw_wildprompt_selections)),
('Performance', performance_selection),
('Resolution', str((width, height))),
('Sharpness', sharpness),
Expand All @@ -798,7 +827,7 @@ def callback(step, x0, x, total_steps, y):
if n != 'None':
d.append((f'LoRA {li + 1}', f'{n} : {w}'))
d.append(('Version', 'v' + fooocus_version.version))
log(x, d)
log(x, d, str(wildprompt_selections).replace("'", ""))

yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1)
except ldm_patched.modules.model_management.InterruptProcessingException as e:
Expand Down
5 changes: 5 additions & 0 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
],
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
)
default_wildprompts = get_config_item_or_set_default(
key='default_wildprompts',
default_value=[],
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_wildprompt_names for y in x)
)
default_prompt_negative = get_config_item_or_set_default(
key='default_prompt_negative',
default_value='',
Expand Down
191 changes: 158 additions & 33 deletions modules/private_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,21 @@
from PIL import Image
from modules.util import generate_temp_filename


log_cache = {}


def get_current_html_path():
date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs,
extension='png')
date_string, local_temp_filename, only_name, logpath = generate_temp_filename(folder=modules.config.path_outputs, extension='png')
html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')
return html_name


def log(img, dic):
def log(img, dic, wildprompt=''):
if args_manager.args.disable_image_log:
return

date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs, extension='png')
date_string, local_temp_filename, only_name, logpath = generate_temp_filename(folder=modules.config.path_outputs, extension='png', wildprompt=wildprompt)
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
Image.fromarray(img).save(local_temp_filename)
html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')
html_name = os.path.join(os.path.dirname(logpath), 'log.html')

css_styles = (
"<style>"
Expand All @@ -40,35 +36,156 @@ def log(img, dic):
"hr { border-color: gray; } "
"button { background-color: black; color: white; border: 1px solid grey; border-radius: 5px; padding: 5px 10px; text-align: center; display: inline-block; font-size: 16px; cursor: pointer; }"
"button:hover {background-color: grey; color: black;}"
"#filters { display: flex; flex-wrap: wrap; gap: 2rem; padding: 2rem; }"
".filter-heading { font-weight: bold; font-size: 1.2rem; margin-bottom: 0.5em;}"
"label { display: block; margin-bottom: 0.5em; cursor: pointer; }"
"</style>"
)

js = (
"""<script>
function to_clipboard(txt) {
txt = decodeURIComponent(txt);
if (navigator.clipboard && navigator.permissions) {
navigator.clipboard.writeText(txt)
} else {
const textArea = document.createElement('textArea')
textArea.value = txt
textArea.style.width = 0
textArea.style.position = 'fixed'
textArea.style.left = '-999px'
textArea.style.top = '10px'
textArea.setAttribute('readonly', 'readonly')
document.body.appendChild(textArea)

textArea.select()
document.execCommand('copy')
document.body.removeChild(textArea)
}
alert('Copied to Clipboard!\\nPaste to prompt area to load parameters.\\nCurrent clipboard content is:\\n\\n' + txt);
}
</script>"""
"""
<script>
function to_clipboard(txt) {
txt = decodeURIComponent(txt);
if (navigator.clipboard && navigator.permissions) {
navigator.clipboard.writeText(txt)
} else {
const textArea = document.createElement('textArea')
textArea.value = txt
textArea.style.width = 0
textArea.style.position = 'fixed'
textArea.style.left = '-999px'
textArea.style.top = '10px'
textArea.setAttribute('readonly', 'readonly')
document.body.appendChild(textArea)

textArea.select()
document.execCommand('copy')
document.body.removeChild(textArea)
}
alert('Copied to Clipboard!\\nPaste to prompt area to load parameters.\\nCurrent clipboard content is:\\n\\n' + txt);
}

// Function to update visibility of log items based on filters
function updateFilters() {
var baseModelFilters = document.querySelectorAll('input[name="baseModelFilter"]:checked');
var wildpromptFilters = document.querySelectorAll('input[name="wildpromptFilter"]:checked');

// Loop through all log items
var logItems = document.querySelectorAll('.image-container');
logItems.forEach(function(item) {
var baseModel = item.getAttribute('data-model');
var wildpromptsAttr = item.getAttribute('data-wildprompts');
var wildprompts = [];
if (wildpromptsAttr && wildpromptsAttr != '[]') {
wildprompts = wildpromptsAttr.match(/'[^']+'/g);
if (wildprompts) {
wildprompts = wildprompts.map(function(item) {
return item.replace(/'/g, '');
});
}
}

var isVisible = true;

// Check if base model filter is active
if (baseModelFilters.length > 0) {
var baseModelMatch = false;
baseModelFilters.forEach(function(filter) {
if (baseModel === filter.value) {
baseModelMatch = true;
}
});
if (!baseModelMatch) {
isVisible = false;
}
}

// Check if wildprompt filter is active
if (wildpromptFilters.length > 0) {
var wildpromptMatch = false;
wildpromptFilters.forEach(function(filter) {
if (wildprompts.includes(filter.value)) {
wildpromptMatch = true;
}
});
if (!wildpromptMatch) {
isVisible = false;
}
}

// Update visibility
if (isVisible) {
item.style.display = 'block';
} else {
item.style.display = 'none';
}
});
}

// Function to initialize filters
function initFilters() {
// Base model filter
var baseModels = {};
var baseModelCheckboxes = document.getElementById('baseModelFilters');
var baseModelOptions = document.querySelectorAll('.image-container');
baseModelOptions.forEach(function(item) {
var baseModel = item.getAttribute('data-model');
if (!baseModels.hasOwnProperty(baseModel)) {
baseModels[baseModel] = 0;
}
baseModels[baseModel]++;
});
for (var model in baseModels) {
var checkbox = document.createElement('input');
checkbox.type = 'checkbox';
checkbox.name = 'baseModelFilter';
checkbox.value = model;
checkbox.addEventListener('change', updateFilters);
var label = document.createElement('label');
label.appendChild(checkbox);
label.appendChild(document.createTextNode(model + ' (' + baseModels[model] + ')'));
baseModelCheckboxes.appendChild(label);
}

// Wildprompt filter
var wildpromptCheckboxes = document.getElementById('wildpromptFilters');
var wildprompts = {};
baseModelOptions.forEach(function(item) {
var wildpromptsAttr = item.getAttribute('data-wildprompts');
var prompts = [];
if (wildpromptsAttr && wildpromptsAttr != '[]') {
prompts = wildpromptsAttr.match(/'[^']+'/g).map(function(item) {
return item.replace(/'/g, '');
});
}
prompts.forEach(function(prompt) {
if (!wildprompts.hasOwnProperty(prompt)) {
wildprompts[prompt] = 0;
}
wildprompts[prompt]++;
});
});
for (var prompt in wildprompts) {
var checkbox = document.createElement('input');
checkbox.type = 'checkbox';
checkbox.name = 'wildpromptFilter';
checkbox.value = prompt;
checkbox.addEventListener('change', updateFilters);
var label = document.createElement('label');
label.appendChild(checkbox);
label.appendChild(document.createTextNode(prompt + ' (' + wildprompts[prompt] + ')'));
wildpromptCheckboxes.appendChild(label);
}
}

// Initialize filters when the page is loaded
window.addEventListener('load', initFilters);
</script>
"""
)

begin_part = f"<!DOCTYPE html><html><head><title>Fooocus Log {date_string}</title>{css_styles}</head><body>{js}<p>Fooocus Log {date_string} (private)</p>\n<p>All images are clean, without any hidden data/meta, and safe to share with others.</p><!--fooocus-log-split-->\n\n"
begin_part = f"<!DOCTYPE html><html><head><title>Fooocus Log {date_string}</title>\n\n{css_styles}\n\n</head><body>\n\n{js}\n\n<div id=\"filters\"><div id=\"baseModelFilters\"><div class=\"filter-heading\">Base Model</div></div><div id=\"wildpromptFilters\"><div class=\"filter-heading\">Wildprompts</div></div>\n\n</div><!--fooocus-log-split-->\n\n"
end_part = f'\n<!--fooocus-log-split--></body></html>'

middle_part = log_cache.get(html_name, "")
Expand All @@ -82,8 +199,16 @@ def log(img, dic):
middle_part = existing_split[0]

div_name = only_name.replace('.', '_')
item = f"<div id=\"{div_name}\" class=\"image-container\"><hr><table><tr>\n"
item += f"<td><a href=\"{only_name}\" target=\"_blank\"><img src='{only_name}' onerror=\"this.closest('.image-container').style.display='none';\" loading='lazy'/></a><div>{only_name}</div></td>"
for key, value in dic:
if key == 'Base Model':
base_model = value
break
for key, value in dic:
if key == 'Wildprompts':
wildprompts = value
break
item = f"<div id=\"{div_name}\" class=\"image-container\" data-model=\"{base_model}\" data-wildprompts=\"{wildprompts}\"><hr><table><tr>\n"
item += f"<td><a href=\"{only_name}\" target=\"_blank\"><img src='{only_name}' onerror=\"this.closest('.image-container').remove();\"/></a><div>{only_name}</div></td>"
item += "<td><table class='metadata'>"
for key, value in dic:
value_txt = str(value).replace('\n', ' <br/> ')
Expand Down
37 changes: 36 additions & 1 deletion modules/sdxl_styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
wildprompts_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildprompts/'))
wildcards_max_bfs_depth = 64


Expand All @@ -24,8 +25,14 @@ def normalize_key(k):


styles = {}
wildprompts = {}

styles_files = get_files_from_folder(styles_path, ['.json'])
wildcards_files = get_files_from_folder(wildcards_path, ['.txt'])

def GetLegalWildpromptNames():
wildprompts_files = get_files_from_folder(wildprompts_path, ['.txt'])
return [os.path.splitext(file)[0] for file in wildprompts_files]

for x in ['sdxl_styles_fooocus.json',
'sdxl_styles_sai.json',
Expand All @@ -52,7 +59,8 @@ def normalize_key(k):
style_keys = list(styles.keys())
fooocus_expansion = "Fooocus V2"
legal_style_names = [fooocus_expansion] + style_keys

legal_wildcard_names = [os.path.splitext(file)[0] for file in wildcards_files]
legal_wildprompt_names = GetLegalWildpromptNames()

def apply_style(style, positive):
p, n = styles[style]
Expand Down Expand Up @@ -80,3 +88,30 @@ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):

print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
return wildcard_text

def apply_wildprompts(wildprompt_selections, rng):
prompts = []

for wildprompt_selection in wildprompt_selections:
try:
wildprompt_text = open(os.path.join(wildprompts_path, f'{wildprompt_selection}.txt'), encoding='utf-8').read().splitlines()
wildprompt_text = [x for x in wildprompt_text if x != '']
assert len(wildprompt_text) > 0
prompts.append(rng.choice(wildprompt_text))
except:
print(f'[Wildprompts] Warning: {wildprompt_selection}.txt missing or empty. ')

return ', '.join(prompts)

def get_all_wildprompts(wildprompt_selections):
prompts = []

try:
wildprompt_text = open(os.path.join(wildprompts_path, f'{wildprompt_selections[0]}.txt'), encoding='utf-8').read().splitlines()
wildprompt_text = [x for x in wildprompt_text if x != '']
assert len(wildprompt_text) > 0
prompts.extend(wildprompt_text)
except:
print(f'[Wildprompts] Warning: {wildprompt_selections[0]}.txt missing or empty. ')

return prompts
Loading