mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui
synced 2025-01-08 12:07:30 +08:00
extra networks UI
rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>
This commit is contained in:
parent
e33cace2c2
commit
40ff6db532
BIN
html/card-no-preview.png
Normal file
BIN
html/card-no-preview.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 82 KiB |
11
html/extra-networks-card.html
Normal file
11
html/extra-networks-card.html
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
<div class='card' {preview_html} onclick='return cardClicked({prompt}, {allow_negative_prompt})'>
|
||||||
|
<div class='actions'>
|
||||||
|
<div class='additional'>
|
||||||
|
<ul>
|
||||||
|
<a href="#" title="replace preview image with currently selected in gallery" onclick='return saveCardPreview(event, {tabname}, {local_preview})'>replace preview</a>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
<span class='name'>{name}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
8
html/extra-networks-no-cards.html
Normal file
8
html/extra-networks-no-cards.html
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<div class='nocards'>
|
||||||
|
<h1>Nothing here. Add some content to the following directories:</h1>
|
||||||
|
|
||||||
|
<ul>
|
||||||
|
{dirs}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
60
javascript/extraNetworks.js
Normal file
60
javascript/extraNetworks.js
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
|
||||||
|
function setupExtraNetworksForTab(tabname){
|
||||||
|
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
||||||
|
|
||||||
|
gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
|
||||||
|
gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
|
||||||
|
}
|
||||||
|
|
||||||
|
var activePromptTextarea = null;
|
||||||
|
var activePositivePromptTextarea = null;
|
||||||
|
|
||||||
|
function setupExtraNetworks(){
|
||||||
|
setupExtraNetworksForTab('txt2img')
|
||||||
|
setupExtraNetworksForTab('img2img')
|
||||||
|
|
||||||
|
function registerPrompt(id, isNegative){
|
||||||
|
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||||
|
|
||||||
|
if (activePromptTextarea == null){
|
||||||
|
activePromptTextarea = textarea
|
||||||
|
}
|
||||||
|
if (activePositivePromptTextarea == null && ! isNegative){
|
||||||
|
activePositivePromptTextarea = textarea
|
||||||
|
}
|
||||||
|
|
||||||
|
textarea.addEventListener("focus", function(){
|
||||||
|
activePromptTextarea = textarea;
|
||||||
|
if(! isNegative) activePositivePromptTextarea = textarea;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
registerPrompt('txt2img_prompt')
|
||||||
|
registerPrompt('txt2img_neg_prompt', true)
|
||||||
|
registerPrompt('img2img_prompt')
|
||||||
|
registerPrompt('img2img_neg_prompt', true)
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiLoaded(setupExtraNetworks)
|
||||||
|
|
||||||
|
function cardClicked(textToAdd, allowNegativePrompt){
|
||||||
|
textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea
|
||||||
|
|
||||||
|
textarea.value = textarea.value + " " + textToAdd
|
||||||
|
updateInput(textarea)
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
function saveCardPreview(event, tabname, filename){
|
||||||
|
textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
|
||||||
|
button = gradioApp().getElementById(tabname + '_save_preview')
|
||||||
|
|
||||||
|
textarea.value = filename
|
||||||
|
updateInput(textarea)
|
||||||
|
|
||||||
|
button.click()
|
||||||
|
|
||||||
|
event.stopPropagation()
|
||||||
|
event.preventDefault()
|
||||||
|
}
|
@ -21,6 +21,8 @@ titles = {
|
|||||||
"\U0001F5D1": "Clear prompt",
|
"\U0001F5D1": "Clear prompt",
|
||||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||||
"\u{1f4d2}": "Paste available values into the field",
|
"\u{1f4d2}": "Paste available values into the field",
|
||||||
|
"\u{1f3b4}": "Show extra networks",
|
||||||
|
|
||||||
|
|
||||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||||
|
@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
|||||||
return [prompt, negative_prompt]
|
return [prompt, negative_prompt]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
opts = {}
|
opts = {}
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(Object.keys(opts).length != 0) return;
|
if(Object.keys(opts).length != 0) return;
|
||||||
@ -239,11 +237,14 @@ onUiUpdate(function(){
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
prompt.parentElement.insertBefore(counter, prompt)
|
prompt.parentElement.insertBefore(counter, prompt)
|
||||||
counter.classList.add("token-counter")
|
counter.classList.add("token-counter")
|
||||||
prompt.parentElement.style.position = "relative"
|
prompt.parentElement.style.position = "relative"
|
||||||
|
|
||||||
textarea.addEventListener("input", () => update_token_counter(id_button));
|
textarea.addEventListener("input", function(){
|
||||||
|
update_token_counter(id_button);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
||||||
@ -261,10 +262,8 @@ onUiUpdate(function(){
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
onOptionsChanged(function(){
|
onOptionsChanged(function(){
|
||||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||||
|
@ -480,7 +480,7 @@ class Api:
|
|||||||
def train_hypernetwork(self, args: dict):
|
def train_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
initial_hypernetwork = shared.loaded_hypernetwork
|
shared.loaded_hypernetworks = []
|
||||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
error = None
|
error = None
|
||||||
filename = ''
|
filename = ''
|
||||||
@ -491,16 +491,15 @@ class Api:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = e
|
error = e
|
||||||
finally:
|
finally:
|
||||||
shared.loaded_hypernetwork = initial_hypernetwork
|
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
|
||||||
except AssertionError as msg:
|
except AssertionError as msg:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "train embedding error: {error}".format(error = error))
|
return TrainResponse(info="train embedding error: {error}".format(error=error))
|
||||||
|
|
||||||
def get_memory(self):
|
def get_memory(self):
|
||||||
try:
|
try:
|
||||||
|
147
modules/extra_networks.py
Normal file
147
modules/extra_networks.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from modules import errors
|
||||||
|
|
||||||
|
extra_network_registry = {}
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
extra_network_registry.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def register_extra_network(extra_network):
|
||||||
|
extra_network_registry[extra_network.name] = extra_network
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetworkParams:
|
||||||
|
def __init__(self, items=None):
|
||||||
|
self.items = items or []
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetwork:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def activate(self, p, params_list):
|
||||||
|
"""
|
||||||
|
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
|
||||||
|
Passes arguments related to this extra network in params_list.
|
||||||
|
User passes arguments by specifying this in his prompt:
|
||||||
|
|
||||||
|
<name:arg1:arg2:arg3>
|
||||||
|
|
||||||
|
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
||||||
|
separated by colon.
|
||||||
|
|
||||||
|
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
|
||||||
|
in this case, all effects of this extra networks should be disabled.
|
||||||
|
|
||||||
|
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
||||||
|
|
||||||
|
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
|
||||||
|
|
||||||
|
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
|
||||||
|
|
||||||
|
params_list will be:
|
||||||
|
|
||||||
|
[
|
||||||
|
ExtraNetworkParams(items=["agm", "1.1"]),
|
||||||
|
ExtraNetworkParams(items=["ray"])
|
||||||
|
]
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def deactivate(self, p):
|
||||||
|
"""
|
||||||
|
Called at the end of processing for housekeeping. No need to do anything here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def activate(p, extra_network_data):
|
||||||
|
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||||
|
activate for all remaining registered networks with an empty argument list"""
|
||||||
|
|
||||||
|
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||||
|
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||||
|
if extra_network is None:
|
||||||
|
print(f"Skipping unknown extra network: {extra_network_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
extra_network.activate(p, extra_network_args)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
||||||
|
|
||||||
|
for extra_network_name, extra_network in extra_network_registry.items():
|
||||||
|
args = extra_network_data.get(extra_network_name, None)
|
||||||
|
if args is not None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
extra_network.activate(p, [])
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"activating extra network {extra_network_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def deactivate(p, extra_network_data):
|
||||||
|
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||||
|
deactivate for all remaining registered networks"""
|
||||||
|
|
||||||
|
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||||
|
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||||
|
if extra_network is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
extra_network.deactivate(p)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"deactivating extra network {extra_network_name}")
|
||||||
|
|
||||||
|
for extra_network_name, extra_network in extra_network_registry.items():
|
||||||
|
args = extra_network_data.get(extra_network_name, None)
|
||||||
|
if args is not None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
extra_network.deactivate(p)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
|
||||||
|
|
||||||
|
|
||||||
|
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt(prompt):
|
||||||
|
res = defaultdict(list)
|
||||||
|
|
||||||
|
def found(m):
|
||||||
|
name = m.group(1)
|
||||||
|
args = m.group(2)
|
||||||
|
|
||||||
|
res[name].append(ExtraNetworkParams(items=args.split(":")))
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
prompt = re.sub(re_extra_net, found, prompt)
|
||||||
|
|
||||||
|
return prompt, res
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompts(prompts):
|
||||||
|
res = []
|
||||||
|
extra_data = None
|
||||||
|
|
||||||
|
for prompt in prompts:
|
||||||
|
updated_prompt, parsed_extra_data = parse_prompt(prompt)
|
||||||
|
|
||||||
|
if extra_data is None:
|
||||||
|
extra_data = parsed_extra_data
|
||||||
|
|
||||||
|
res.append(updated_prompt)
|
||||||
|
|
||||||
|
return res, extra_data
|
||||||
|
|
21
modules/extra_networks_hypernet.py
Normal file
21
modules/extra_networks_hypernet.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from modules import extra_networks
|
||||||
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('hypernet')
|
||||||
|
|
||||||
|
def activate(self, p, params_list):
|
||||||
|
names = []
|
||||||
|
multipliers = []
|
||||||
|
for params in params_list:
|
||||||
|
assert len(params.items) > 0
|
||||||
|
|
||||||
|
names.append(params.items[0])
|
||||||
|
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||||
|
|
||||||
|
hypernetwork.load_hypernetworks(names, multipliers)
|
||||||
|
|
||||||
|
def deactivate(p, self):
|
||||||
|
pass
|
@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
|
|||||||
from modules import ui
|
from modules import ui
|
||||||
|
|
||||||
settings_map = {
|
settings_map = {
|
||||||
'sd_hypernetwork': 'Hypernet',
|
|
||||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
|
||||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||||
'inpainting_mask_weight': 'Conditional mask weight',
|
'inpainting_mask_weight': 'Conditional mask weight',
|
||||||
'sd_model_checkpoint': 'Model hash',
|
'sd_model_checkpoint': 'Model hash',
|
||||||
@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "Clip skip" not in res:
|
if "Clip skip" not in res:
|
||||||
res["Clip skip"] = "1"
|
res["Clip skip"] = "1"
|
||||||
|
|
||||||
if "Hypernet strength" not in res:
|
hypernet = res.get("Hypernet", None)
|
||||||
res["Hypernet strength"] = "1"
|
if hypernet is not None:
|
||||||
|
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
||||||
if "Hypernet" in res:
|
|
||||||
hypernet_name = res["Hypernet"]
|
|
||||||
hypernet_hash = res.get("Hypernet hash", None)
|
|
||||||
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
|
||||||
|
|
||||||
if "Hires resize-1" not in res:
|
if "Hires resize-1" not in res:
|
||||||
res["Hires resize-1"] = 0
|
res["Hires resize-1"] = 0
|
||||||
|
@ -25,7 +25,6 @@ from statistics import stdev, mean
|
|||||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
|
||||||
activation_dict = {
|
activation_dict = {
|
||||||
"linear": torch.nn.Identity,
|
"linear": torch.nn.Identity,
|
||||||
"relu": torch.nn.ReLU,
|
"relu": torch.nn.ReLU,
|
||||||
@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.multiplier = 1.0
|
||||||
|
|
||||||
assert layer_structure is not None, "layer_structure must not be None"
|
assert layer_structure is not None, "layer_structure must not be None"
|
||||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||||
@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
state_dict[to] = x
|
state_dict[to] = x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
|
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
||||||
|
|
||||||
def trainables(self):
|
def trainables(self):
|
||||||
layer_structure = []
|
layer_structure = []
|
||||||
@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
return layer_structure
|
return layer_structure
|
||||||
|
|
||||||
|
|
||||||
def apply_strength(value=None):
|
|
||||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
|
||||||
|
|
||||||
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
||||||
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
||||||
if layer_structure is None:
|
if layer_structure is None:
|
||||||
@ -192,6 +190,20 @@ class Hypernetwork:
|
|||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = mode
|
param.requires_grad = mode
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
for k, layers in self.layers.items():
|
||||||
|
for layer in layers:
|
||||||
|
layer.to(device)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_multiplier(self, multiplier):
|
||||||
|
for k, layers in self.layers.items():
|
||||||
|
for layer in layers:
|
||||||
|
layer.multiplier = multiplier
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
for k, layers in self.layers.items():
|
for k, layers in self.layers.items():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
@ -269,11 +281,13 @@ class Hypernetwork:
|
|||||||
self.optimizer_state_dict = None
|
self.optimizer_state_dict = None
|
||||||
if self.optimizer_state_dict:
|
if self.optimizer_state_dict:
|
||||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||||
print("Loaded existing optimizer from checkpoint")
|
if shared.opts.print_hypernet_extra:
|
||||||
print(f"Optimizer name is {self.optimizer_name}")
|
print("Loaded existing optimizer from checkpoint")
|
||||||
|
print(f"Optimizer name is {self.optimizer_name}")
|
||||||
else:
|
else:
|
||||||
self.optimizer_name = "AdamW"
|
self.optimizer_name = "AdamW"
|
||||||
print("No saved optimizer exists in checkpoint")
|
if shared.opts.print_hypernet_extra:
|
||||||
|
print("No saved optimizer exists in checkpoint")
|
||||||
|
|
||||||
for size, sd in state_dict.items():
|
for size, sd in state_dict.items():
|
||||||
if type(size) == int:
|
if type(size) == int:
|
||||||
@ -306,23 +320,43 @@ def list_hypernetworks(path):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetwork(filename):
|
def load_hypernetwork(name):
|
||||||
path = shared.hypernetworks.get(filename, None)
|
path = shared.hypernetworks.get(name, None)
|
||||||
# Prevent any file named "None.pt" from being loaded.
|
|
||||||
if path is not None and filename != "None":
|
|
||||||
print(f"Loading hypernetwork {filename}")
|
|
||||||
try:
|
|
||||||
shared.loaded_hypernetwork = Hypernetwork()
|
|
||||||
shared.loaded_hypernetwork.load(path)
|
|
||||||
|
|
||||||
except Exception:
|
if path is None:
|
||||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
return None
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
else:
|
|
||||||
if shared.loaded_hypernetwork is not None:
|
|
||||||
print("Unloading hypernetwork")
|
|
||||||
|
|
||||||
shared.loaded_hypernetwork = None
|
hypernetwork = Hypernetwork()
|
||||||
|
|
||||||
|
try:
|
||||||
|
hypernetwork.load(path)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return hypernetwork
|
||||||
|
|
||||||
|
|
||||||
|
def load_hypernetworks(names, multipliers=None):
|
||||||
|
already_loaded = {}
|
||||||
|
|
||||||
|
for hypernetwork in shared.loaded_hypernetworks:
|
||||||
|
if hypernetwork.name in names:
|
||||||
|
already_loaded[hypernetwork.name] = hypernetwork
|
||||||
|
|
||||||
|
shared.loaded_hypernetworks.clear()
|
||||||
|
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
hypernetwork = already_loaded.get(name, None)
|
||||||
|
if hypernetwork is None:
|
||||||
|
hypernetwork = load_hypernetwork(name)
|
||||||
|
|
||||||
|
if hypernetwork is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
|
||||||
|
shared.loaded_hypernetworks.append(hypernetwork)
|
||||||
|
|
||||||
|
|
||||||
def find_closest_hypernetwork_name(search: str):
|
def find_closest_hypernetwork_name(search: str):
|
||||||
@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
|
|||||||
return applicable[0]
|
return applicable[0]
|
||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||||
|
|
||||||
if hypernetwork_layers is None:
|
if hypernetwork_layers is None:
|
||||||
return context, context
|
return context_k, context_v
|
||||||
|
|
||||||
if layer is not None:
|
if layer is not None:
|
||||||
layer.hyper_k = hypernetwork_layers[0]
|
layer.hyper_k = hypernetwork_layers[0]
|
||||||
layer.hyper_v = hypernetwork_layers[1]
|
layer.hyper_v = hypernetwork_layers[1]
|
||||||
|
|
||||||
context_k = hypernetwork_layers[0](context)
|
context_k = hypernetwork_layers[0](context_k)
|
||||||
context_v = hypernetwork_layers[1](context)
|
context_v = hypernetwork_layers[1](context_v)
|
||||||
|
return context_k, context_v
|
||||||
|
|
||||||
|
|
||||||
|
def apply_hypernetworks(hypernetworks, context, layer=None):
|
||||||
|
context_k = context
|
||||||
|
context_v = context
|
||||||
|
for hypernetwork in hypernetworks:
|
||||||
|
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
|
||||||
|
|
||||||
return context_k, context_v
|
return context_k, context_v
|
||||||
|
|
||||||
|
|
||||||
@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
|||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
|
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
|
||||||
k = self.to_k(context_k)
|
k = self.to_k(context_k)
|
||||||
v = self.to_v(context_v)
|
v = self.to_v(context_v)
|
||||||
|
|
||||||
@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
template_file = template_file.path
|
template_file = template_file.path
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
shared.loaded_hypernetwork = Hypernetwork()
|
hypernetwork = Hypernetwork()
|
||||||
shared.loaded_hypernetwork.load(path)
|
hypernetwork.load(path)
|
||||||
|
shared.loaded_hypernetworks = [hypernetwork]
|
||||||
|
|
||||||
shared.state.job = "train-hypernetwork"
|
shared.state.job = "train-hypernetwork"
|
||||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||||
@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
else:
|
else:
|
||||||
images_dir = None
|
images_dir = None
|
||||||
|
|
||||||
hypernetwork = shared.loaded_hypernetwork
|
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
initial_step = hypernetwork.step or 0
|
initial_step = hypernetwork.step or 0
|
||||||
|
@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared
|
|||||||
not_available = ["hardswish", "multiheadattention"]
|
not_available = ["hardswish", "multiheadattention"]
|
||||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||||
|
|
||||||
@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(*args):
|
def train_hypernetwork(*args):
|
||||||
|
shared.loaded_hypernetworks = []
|
||||||
initial_hypernetwork = shared.loaded_hypernetwork
|
|
||||||
|
|
||||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||||
|
|
||||||
@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
shared.loaded_hypernetwork = initial_hypernetwork
|
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
|
@ -13,7 +13,7 @@ from skimage import exposure
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
|
||||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
|
|
||||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
|
||||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||||
@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
try:
|
try:
|
||||||
for k, v in p.override_settings.items():
|
for k, v in p.override_settings.items():
|
||||||
setattr(opts, k, v)
|
setattr(opts, k, v)
|
||||||
if k == 'sd_hypernetwork':
|
|
||||||
shared.reload_hypernetworks() # make onchange call for changing hypernet
|
|
||||||
|
|
||||||
if k == 'sd_model_checkpoint':
|
if k == 'sd_model_checkpoint':
|
||||||
sd_models.reload_model_weights() # make onchange call for changing SD model
|
sd_models.reload_model_weights()
|
||||||
|
|
||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights() # make onchange call for changing VAE
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
res = process_images_inner(p)
|
res = process_images_inner(p)
|
||||||
|
|
||||||
@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.override_settings_restore_afterwards:
|
if p.override_settings_restore_afterwards:
|
||||||
for k, v in stored_opts.items():
|
for k, v in stored_opts.items():
|
||||||
setattr(opts, k, v)
|
setattr(opts, k, v)
|
||||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
if k == 'sd_model_checkpoint':
|
||||||
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
|
sd_models.reload_model_weights()
|
||||||
if k == 'sd_vae': sd_vae.reload_vae_weights()
|
|
||||||
|
if k == 'sd_vae':
|
||||||
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
cache[0] = (required_prompts, steps)
|
cache[0] = (required_prompts, steps)
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
|
p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
|
||||||
|
|
||||||
with torch.no_grad(), p.sd_model.ema_scope():
|
with torch.no_grad(), p.sd_model.ema_scope():
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||||
|
|
||||||
|
extra_networks.activate(p, extra_network_data)
|
||||||
|
|
||||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [], p.seed, "")
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
|
extra_networks.deactivate(p, extra_network_data)
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||||
|
@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||||
k_in = self.to_k(context_k)
|
k_in = self.to_k(context_k)
|
||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
del context, context_k, context_v, x
|
del context, context_k, context_v, x
|
||||||
@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||||
k_in = self.to_k(context_k)
|
k_in = self.to_k(context_k)
|
||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
|
|
||||||
@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||||
k = self.to_k(context_k) * self.scale
|
k = self.to_k(context_k) * self.scale
|
||||||
v = self.to_v(context_v)
|
v = self.to_v(context_v)
|
||||||
del context, context_k, context_v, x
|
del context, context_k, context_v, x
|
||||||
@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||||
k = self.to_k(context_k)
|
k = self.to_k(context_k)
|
||||||
v = self.to_v(context_v)
|
v = self.to_v(context_v)
|
||||||
del context, context_k, context_v, x
|
del context, context_k, context_v, x
|
||||||
@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||||
k_in = self.to_k(context_k)
|
k_in = self.to_k(context_k)
|
||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ demo = None
|
|||||||
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
|
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||||
@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file
|
|||||||
|
|
||||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||||
hypernetworks = {}
|
hypernetworks = {}
|
||||||
loaded_hypernetwork = None
|
loaded_hypernetworks = []
|
||||||
|
|
||||||
|
|
||||||
def reload_hypernetworks():
|
def reload_hypernetworks():
|
||||||
@ -153,8 +154,6 @@ def reload_hypernetworks():
|
|||||||
global hypernetworks
|
global hypernetworks
|
||||||
|
|
||||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
|
||||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
@ -661,3 +658,17 @@ mem_mon.start()
|
|||||||
def listfiles(dirname):
|
def listfiles(dirname):
|
||||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||||
return [file for file in filenames if os.path.isfile(file)]
|
return [file for file in filenames if os.path.isfile(file)]
|
||||||
|
|
||||||
|
|
||||||
|
def html_path(filename):
|
||||||
|
return os.path.join(script_path, "html", filename)
|
||||||
|
|
||||||
|
|
||||||
|
def html(filename):
|
||||||
|
path = html_path(filename)
|
||||||
|
|
||||||
|
if os.path.exists(path):
|
||||||
|
with open(path, encoding="utf8") as file:
|
||||||
|
return file.read()
|
||||||
|
|
||||||
|
return ""
|
||||||
|
@ -50,6 +50,7 @@ class Embedding:
|
|||||||
self.sd_checkpoint = None
|
self.sd_checkpoint = None
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
self.optimizer_state_dict = None
|
self.optimizer_state_dict = None
|
||||||
|
self.filename = None
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
embedding_data = {
|
embedding_data = {
|
||||||
@ -182,6 +183,7 @@ class EmbeddingDatabase:
|
|||||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||||
embedding.vectors = vec.shape[0]
|
embedding.vectors = vec.shape[0]
|
||||||
embedding.shape = vec.shape[-1]
|
embedding.shape = vec.shape[-1]
|
||||||
|
embedding.filename = path
|
||||||
|
|
||||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
@ -20,7 +20,7 @@ import numpy as np
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
|
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄
|
|||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
apply_style_symbol = '\U0001f4cb' # 📋
|
apply_style_symbol = '\U0001f4cb' # 📋
|
||||||
clear_prompt_symbol = '\U0001F5D1' # 🗑️
|
clear_prompt_symbol = '\U0001F5D1' # 🗑️
|
||||||
|
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||||
|
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
def plaintext_to_html(text):
|
||||||
@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
|||||||
|
|
||||||
def update_token_counter(text, steps):
|
def update_token_counter(text, steps):
|
||||||
try:
|
try:
|
||||||
|
text, _ = extra_networks.parse_prompt(text)
|
||||||
|
|
||||||
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||||
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
||||||
|
|
||||||
@ -354,10 +357,10 @@ def create_toprow(is_img2img):
|
|||||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="roll_col"):
|
with gr.Column(scale=1, elem_id="roll_col"):
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = ToolButton(value=paste_symbol, elem_id="paste")
|
||||||
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
|
clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
||||||
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
|
extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
|
||||||
clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
|
||||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||||
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
|
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
|
||||||
@ -395,11 +398,14 @@ def create_toprow(is_img2img):
|
|||||||
outputs=[],
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row(elem_id=f"{id_part}_styles_row"):
|
||||||
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
|
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
|
||||||
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
|
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
|
||||||
|
|
||||||
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button
|
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply")
|
||||||
|
save_style = ToolButton(value=save_style_symbol, elem_id="style_create")
|
||||||
|
|
||||||
|
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(*args, **kwargs):
|
def setup_progressbar(*args, **kwargs):
|
||||||
@ -616,11 +622,15 @@ def create_ui():
|
|||||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
|
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
|
||||||
|
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||||
|
|
||||||
|
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
|
||||||
|
from modules import ui_extra_networks
|
||||||
|
extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
|
||||||
|
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
@ -794,14 +804,20 @@ def create_ui():
|
|||||||
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||||
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
||||||
|
|
||||||
|
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||||
|
|
||||||
modules.scripts.scripts_current = modules.scripts.scripts_img2img
|
modules.scripts.scripts_current = modules.scripts.scripts_img2img
|
||||||
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
|
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
|
||||||
|
|
||||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||||
|
|
||||||
|
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
|
||||||
|
from modules import ui_extra_networks
|
||||||
|
extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
|
||||||
|
|
||||||
with FormRow().style(equal_height=False):
|
with FormRow().style(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
||||||
copy_image_buttons = []
|
copy_image_buttons = []
|
||||||
@ -1064,6 +1080,8 @@ def create_ui():
|
|||||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||||
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
||||||
|
|
||||||
|
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||||
|
|
||||||
img2img_paste_fields = [
|
img2img_paste_fields = [
|
||||||
(img2img_prompt, "Prompt"),
|
(img2img_prompt, "Prompt"),
|
||||||
(img2img_negative_prompt, "Negative prompt"),
|
(img2img_negative_prompt, "Negative prompt"),
|
||||||
@ -1666,10 +1684,8 @@ def create_ui():
|
|||||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||||
|
|
||||||
if os.path.exists("html/licenses.html"):
|
with gr.TabItem("Licenses"):
|
||||||
with open("html/licenses.html", encoding="utf8") as file:
|
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||||
with gr.TabItem("Licenses"):
|
|
||||||
gr.HTML(file.read(), elem_id="licenses")
|
|
||||||
|
|
||||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||||
|
|
||||||
@ -1756,11 +1772,9 @@ def create_ui():
|
|||||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
||||||
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||||
|
|
||||||
if os.path.exists("html/footer.html"):
|
footer = shared.html("footer.html")
|
||||||
with open("html/footer.html", encoding="utf8") as file:
|
footer = footer.format(versions=versions_html())
|
||||||
footer = file.read()
|
gr.HTML(footer, elem_id="footer")
|
||||||
footer = footer.format(versions=versions_html())
|
|
||||||
gr.HTML(footer, elem_id="footer")
|
|
||||||
|
|
||||||
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
||||||
settings_submit.click(
|
settings_submit.click(
|
||||||
|
@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
|||||||
return "button"
|
return "button"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolButtonTop(gr.Button, gr.components.FormComponent):
|
||||||
|
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(variant="tool-top", **kwargs)
|
||||||
|
|
||||||
|
def get_block_name(self):
|
||||||
|
return "button"
|
||||||
|
|
||||||
|
|
||||||
class FormRow(gr.Row, gr.components.FormComponent):
|
class FormRow(gr.Row, gr.components.FormComponent):
|
||||||
"""Same as gr.Row but fits inside gradio forms"""
|
"""Same as gr.Row but fits inside gradio forms"""
|
||||||
|
|
||||||
|
149
modules/ui_extra_networks.py
Normal file
149
modules/ui_extra_networks.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
import os.path
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
import gradio as gr
|
||||||
|
import json
|
||||||
|
|
||||||
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
|
||||||
|
extra_pages = []
|
||||||
|
|
||||||
|
|
||||||
|
def register_page(page):
|
||||||
|
"""registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions"""
|
||||||
|
|
||||||
|
extra_pages.append(page)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetworksPage:
|
||||||
|
def __init__(self, title):
|
||||||
|
self.title = title
|
||||||
|
self.card_page = shared.html("extra-networks-card.html")
|
||||||
|
self.allow_negative_prompt = False
|
||||||
|
|
||||||
|
def refresh(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def create_html(self, tabname):
|
||||||
|
items_html = ''
|
||||||
|
|
||||||
|
for item in self.list_items():
|
||||||
|
items_html += self.create_html_for_item(item, tabname)
|
||||||
|
|
||||||
|
if items_html == '':
|
||||||
|
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||||
|
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||||
|
|
||||||
|
res = "<div class='extra-network-cards'>" + items_html + "</div>"
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def list_items(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def allowed_directories_for_previews(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def create_html_for_item(self, item, tabname):
|
||||||
|
preview = item.get("preview", None)
|
||||||
|
|
||||||
|
args = {
|
||||||
|
"preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
|
||||||
|
"prompt": json.dumps(item["prompt"]),
|
||||||
|
"tabname": json.dumps(tabname),
|
||||||
|
"local_preview": json.dumps(item["local_preview"]),
|
||||||
|
"name": item["name"],
|
||||||
|
"allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.card_page.format(**args)
|
||||||
|
|
||||||
|
|
||||||
|
def intialize():
|
||||||
|
extra_pages.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetworksUi:
|
||||||
|
def __init__(self):
|
||||||
|
self.pages = None
|
||||||
|
self.stored_extra_pages = None
|
||||||
|
|
||||||
|
self.button_save_preview = None
|
||||||
|
self.preview_target_filename = None
|
||||||
|
|
||||||
|
self.tabname = None
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui(container, button, tabname):
|
||||||
|
ui = ExtraNetworksUi()
|
||||||
|
ui.pages = []
|
||||||
|
ui.stored_extra_pages = extra_pages.copy()
|
||||||
|
ui.tabname = tabname
|
||||||
|
|
||||||
|
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||||
|
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
||||||
|
button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
|
||||||
|
|
||||||
|
for page in ui.stored_extra_pages:
|
||||||
|
with gr.Tab(page.title):
|
||||||
|
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||||
|
ui.pages.append(page_elem)
|
||||||
|
|
||||||
|
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||||
|
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||||
|
|
||||||
|
button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
|
||||||
|
button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
|
||||||
|
|
||||||
|
def refresh():
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for pg in ui.stored_extra_pages:
|
||||||
|
pg.refresh()
|
||||||
|
res.append(pg.create_html(ui.tabname))
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||||
|
|
||||||
|
return ui
|
||||||
|
|
||||||
|
|
||||||
|
def path_is_parent(parent_path, child_path):
|
||||||
|
parent_path = os.path.abspath(parent_path)
|
||||||
|
child_path = os.path.abspath(child_path)
|
||||||
|
|
||||||
|
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
|
||||||
|
|
||||||
|
|
||||||
|
def setup_ui(ui, gallery):
|
||||||
|
def save_preview(index, images, filename):
|
||||||
|
if len(images) == 0:
|
||||||
|
print("There is no image in gallery to save as a preview.")
|
||||||
|
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||||
|
|
||||||
|
index = int(index)
|
||||||
|
index = 0 if index < 0 else index
|
||||||
|
index = len(images) - 1 if index >= len(images) else index
|
||||||
|
|
||||||
|
img_info = images[index if index >= 0 else 0]
|
||||||
|
image = image_from_url_text(img_info)
|
||||||
|
|
||||||
|
is_allowed = False
|
||||||
|
for extra_page in ui.stored_extra_pages:
|
||||||
|
if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
|
||||||
|
is_allowed = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert is_allowed, f'writing to {filename} is not allowed'
|
||||||
|
|
||||||
|
image.save(filename)
|
||||||
|
|
||||||
|
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||||
|
|
||||||
|
ui.button_save_preview.click(
|
||||||
|
fn=save_preview,
|
||||||
|
_js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
|
||||||
|
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
||||||
|
outputs=[*ui.pages]
|
||||||
|
)
|
34
modules/ui_extra_networks_hypernets.py
Normal file
34
modules/ui_extra_networks_hypernets.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from modules import shared, ui_extra_networks
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('Hypernetworks')
|
||||||
|
|
||||||
|
def refresh(self):
|
||||||
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
|
def list_items(self):
|
||||||
|
for name, path in shared.hypernetworks.items():
|
||||||
|
path, ext = os.path.splitext(path)
|
||||||
|
previews = [path + ".png", path + ".preview.png"]
|
||||||
|
|
||||||
|
preview = None
|
||||||
|
for file in previews:
|
||||||
|
if os.path.isfile(file):
|
||||||
|
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
||||||
|
break
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"name": name,
|
||||||
|
"filename": path,
|
||||||
|
"preview": preview,
|
||||||
|
"prompt": f"<hypernet:{name}:1.0>",
|
||||||
|
"local_preview": path + ".png",
|
||||||
|
}
|
||||||
|
|
||||||
|
def allowed_directories_for_previews(self):
|
||||||
|
return [shared.cmd_opts.hypernetwork_dir]
|
||||||
|
|
32
modules/ui_extra_networks_textual_inversion.py
Normal file
32
modules/ui_extra_networks_textual_inversion.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from modules import ui_extra_networks, sd_hijack
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('Textual Inversion')
|
||||||
|
self.allow_negative_prompt = True
|
||||||
|
|
||||||
|
def refresh(self):
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||||
|
|
||||||
|
def list_items(self):
|
||||||
|
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
|
||||||
|
path, ext = os.path.splitext(embedding.filename)
|
||||||
|
preview_file = path + ".preview.png"
|
||||||
|
|
||||||
|
preview = None
|
||||||
|
if os.path.isfile(preview_file):
|
||||||
|
preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"name": embedding.name,
|
||||||
|
"filename": embedding.filename,
|
||||||
|
"preview": preview,
|
||||||
|
"prompt": embedding.name,
|
||||||
|
"local_preview": path + ".preview.png",
|
||||||
|
}
|
||||||
|
|
||||||
|
def allowed_directories_for_previews(self):
|
||||||
|
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
13
script.js
13
script.js
@ -13,6 +13,7 @@ function get_uiCurrentTabContent() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
uiUpdateCallbacks = []
|
uiUpdateCallbacks = []
|
||||||
|
uiLoadedCallbacks = []
|
||||||
uiTabChangeCallbacks = []
|
uiTabChangeCallbacks = []
|
||||||
optionsChangedCallbacks = []
|
optionsChangedCallbacks = []
|
||||||
let uiCurrentTab = null
|
let uiCurrentTab = null
|
||||||
@ -20,6 +21,9 @@ let uiCurrentTab = null
|
|||||||
function onUiUpdate(callback){
|
function onUiUpdate(callback){
|
||||||
uiUpdateCallbacks.push(callback)
|
uiUpdateCallbacks.push(callback)
|
||||||
}
|
}
|
||||||
|
function onUiLoaded(callback){
|
||||||
|
uiLoadedCallbacks.push(callback)
|
||||||
|
}
|
||||||
function onUiTabChange(callback){
|
function onUiTabChange(callback){
|
||||||
uiTabChangeCallbacks.push(callback)
|
uiTabChangeCallbacks.push(callback)
|
||||||
}
|
}
|
||||||
@ -38,8 +42,15 @@ function executeCallbacks(queue, m) {
|
|||||||
queue.forEach(function(x){runCallback(x, m)})
|
queue.forEach(function(x){runCallback(x, m)})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var executedOnLoaded = false;
|
||||||
|
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
document.addEventListener("DOMContentLoaded", function() {
|
||||||
var mutationObserver = new MutationObserver(function(m){
|
var mutationObserver = new MutationObserver(function(m){
|
||||||
|
if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
|
||||||
|
executedOnLoaded = true;
|
||||||
|
executeCallbacks(uiLoadedCallbacks);
|
||||||
|
}
|
||||||
|
|
||||||
executeCallbacks(uiUpdateCallbacks, m);
|
executeCallbacks(uiUpdateCallbacks, m);
|
||||||
const newTab = get_uiCurrentTab();
|
const newTab = get_uiCurrentTab();
|
||||||
if ( newTab && ( newTab !== uiCurrentTab ) ) {
|
if ( newTab && ( newTab !== uiCurrentTab ) ) {
|
||||||
@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
/**
|
/**
|
||||||
* Add a ctrl+enter as a shortcut to start a generation
|
* Add a ctrl+enter as a shortcut to start a generation
|
||||||
*/
|
*/
|
||||||
document.addEventListener('keydown', function(e) {
|
document.addEventListener('keydown', function(e) {
|
||||||
var handled = false;
|
var handled = false;
|
||||||
if (e.key !== undefined) {
|
if (e.key !== undefined) {
|
||||||
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||||
|
@ -11,7 +11,6 @@ import modules.scripts as scripts
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
|
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
|
||||||
from modules.hypernetworks import hypernetwork
|
|
||||||
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs):
|
|||||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork(p, x, xs):
|
|
||||||
if x.lower() in ["", "none"]:
|
|
||||||
name = None
|
|
||||||
else:
|
|
||||||
name = hypernetwork.find_closest_hypernetwork_name(x)
|
|
||||||
if not name:
|
|
||||||
raise RuntimeError(f"Unknown hypernetwork: {x}")
|
|
||||||
hypernetwork.load_hypernetwork(name)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork_strength(p, x, xs):
|
|
||||||
hypernetwork.apply_strength(x)
|
|
||||||
|
|
||||||
|
|
||||||
def confirm_hypernetworks(p, xs):
|
|
||||||
for x in xs:
|
|
||||||
if x.lower() in ["", "none"]:
|
|
||||||
continue
|
|
||||||
if not hypernetwork.find_closest_hypernetwork_name(x):
|
|
||||||
raise RuntimeError(f"Unknown hypernetwork: {x}")
|
|
||||||
|
|
||||||
|
|
||||||
def apply_clip_skip(p, x, xs):
|
def apply_clip_skip(p, x, xs):
|
||||||
opts.data["CLIP_stop_at_last_layers"] = x
|
opts.data["CLIP_stop_at_last_layers"] = x
|
||||||
|
|
||||||
@ -208,8 +185,6 @@ axis_options = [
|
|||||||
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
|
||||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
|
|
||||||
AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
|
|
||||||
AxisOption("Sigma Churn", float, apply_field("s_churn")),
|
AxisOption("Sigma Churn", float, apply_field("s_churn")),
|
||||||
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
||||||
@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
|
|||||||
class SharedSettingsStackHelper(object):
|
class SharedSettingsStackHelper(object):
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
||||||
self.hypernetwork = opts.sd_hypernetwork
|
|
||||||
self.vae = opts.sd_vae
|
self.vae = opts.sd_vae
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, tb):
|
def __exit__(self, exc_type, exc_value, tb):
|
||||||
@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object):
|
|||||||
modules.sd_models.reload_model_weights()
|
modules.sd_models.reload_model_weights()
|
||||||
modules.sd_vae.reload_vae_weights()
|
modules.sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
hypernetwork.load_hypernetwork(self.hypernetwork)
|
|
||||||
hypernetwork.apply_strength()
|
|
||||||
|
|
||||||
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
|
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
|
||||||
|
|
||||||
|
|
||||||
|
194
style.css
194
style.css
@ -132,13 +132,6 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
#roll_col > button {
|
#roll_col > button {
|
||||||
min-width: 2em;
|
|
||||||
min-height: 2em;
|
|
||||||
max-width: 2em;
|
|
||||||
max-height: 2em;
|
|
||||||
flex-grow: 0;
|
|
||||||
padding-left: 0.25em;
|
|
||||||
padding-right: 0.25em;
|
|
||||||
margin: 0.1em 0;
|
margin: 0.1em 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,9 +139,10 @@
|
|||||||
min-width: 0 !important;
|
min-width: 0 !important;
|
||||||
max-width: 8em !important;
|
max-width: 8em !important;
|
||||||
margin-right: 1em;
|
margin-right: 1em;
|
||||||
|
gap: 0;
|
||||||
}
|
}
|
||||||
#interrogate, #deepbooru{
|
#interrogate, #deepbooru{
|
||||||
margin: 0em 0.25em 0.9em 0.25em;
|
margin: 0em 0.25em 0.5em 0.25em;
|
||||||
min-width: 8em;
|
min-width: 8em;
|
||||||
max-width: 8em;
|
max-width: 8em;
|
||||||
}
|
}
|
||||||
@ -157,8 +151,17 @@
|
|||||||
min-width: 8em !important;
|
min-width: 8em !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#txt2img_styles_row, #img2img_styles_row{
|
||||||
|
gap: 0.25em;
|
||||||
|
margin-top: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
#txt2img_styles_row > button, #img2img_styles_row > button{
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
#txt2img_styles, #img2img_styles{
|
#txt2img_styles, #img2img_styles{
|
||||||
margin-top: 1em;
|
padding: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_styles ul, #img2img_styles ul{
|
#txt2img_styles ul, #img2img_styles ul{
|
||||||
@ -635,16 +638,20 @@ canvas[key="mask"] {
|
|||||||
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
|
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
|
||||||
}
|
}
|
||||||
|
|
||||||
.gr-button-tool{
|
.gr-button-tool, .gr-button-tool-top{
|
||||||
max-width: 2.5em;
|
max-width: 2.5em;
|
||||||
min-width: 2.5em !important;
|
min-width: 2.5em !important;
|
||||||
height: 2.4em;
|
height: 2.4em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-button-tool{
|
||||||
|
margin: 0.6em 0em 0.55em 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-button-tool-top, #settings .gr-button-tool{
|
||||||
margin: 1.6em 0.7em 0.55em 0;
|
margin: 1.6em 0.7em 0.55em 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#tab_modelmerger .gr-button-tool{
|
|
||||||
margin: 0.6em 0em 0.55em 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
#modelmerger_results_container{
|
#modelmerger_results_container{
|
||||||
margin-top: 1em;
|
margin-top: 1em;
|
||||||
@ -763,81 +770,88 @@ footer {
|
|||||||
line-height: 2.4em;
|
line-height: 2.4em;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
#txt2img_extra_networks, #img2img_extra_networks{
|
||||||
The rtl media type will only be activated by the logic in javascript/localization.js.
|
margin-top: -1em;
|
||||||
If you change anything above, you need to make sure it is RTL compliant by just running
|
|
||||||
your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
|
|
||||||
Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
|
|
||||||
@media rtl {
|
|
||||||
/* this part was added manually */
|
|
||||||
:host {
|
|
||||||
direction: rtl;
|
|
||||||
}
|
|
||||||
select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
|
|
||||||
direction: ltr;
|
|
||||||
}
|
|
||||||
#script_list > label > select,
|
|
||||||
#x_type > label > select,
|
|
||||||
#y_type > label > select {
|
|
||||||
direction: rtl;
|
|
||||||
}
|
|
||||||
.gr-radio, .gr-checkbox{
|
|
||||||
margin-left: 0.25em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* automatically generated with few manual modifications */
|
|
||||||
.performance .time {
|
|
||||||
margin-right: unset;
|
|
||||||
margin-left: 0;
|
|
||||||
}
|
|
||||||
.justify-center.overflow-x-scroll {
|
|
||||||
justify-content: right;
|
|
||||||
}
|
|
||||||
.justify-center.overflow-x-scroll button:first-of-type {
|
|
||||||
margin-left: unset;
|
|
||||||
margin-right: auto;
|
|
||||||
}
|
|
||||||
.justify-center.overflow-x-scroll button:last-of-type {
|
|
||||||
margin-right: unset;
|
|
||||||
margin-left: auto;
|
|
||||||
}
|
|
||||||
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
|
|
||||||
margin-right: unset;
|
|
||||||
margin-left: 8em;
|
|
||||||
}
|
|
||||||
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
|
|
||||||
right: unset;
|
|
||||||
left: 0;
|
|
||||||
}
|
|
||||||
.progressDiv .progress{
|
|
||||||
padding: 0 0 0 8px;
|
|
||||||
text-align: left;
|
|
||||||
}
|
|
||||||
#lightboxModal{
|
|
||||||
left: unset;
|
|
||||||
right: 0;
|
|
||||||
}
|
|
||||||
.modalPrev, .modalNext{
|
|
||||||
border-radius: 3px 0 0 3px;
|
|
||||||
}
|
|
||||||
.modalNext {
|
|
||||||
right: unset;
|
|
||||||
left: 0;
|
|
||||||
border-radius: 0 3px 3px 0;
|
|
||||||
}
|
|
||||||
#imageARPreview{
|
|
||||||
left:unset;
|
|
||||||
right:0px;
|
|
||||||
}
|
|
||||||
#txt2img_skip, #img2img_skip{
|
|
||||||
right: unset;
|
|
||||||
left: 0px;
|
|
||||||
}
|
|
||||||
#context-menu{
|
|
||||||
box-shadow:-1px 1px 2px #CE6400;
|
|
||||||
}
|
|
||||||
.gr-box > div > div > input.gr-text-input{
|
|
||||||
right: unset;
|
|
||||||
left: 0.5em;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.extra-networks > div > [id *= '_extra_']{
|
||||||
|
margin: 0.3em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .nocards{
|
||||||
|
margin: 1.25em 0.5em 0.5em 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .nocards h1{
|
||||||
|
font-size: 1.5em;
|
||||||
|
margin-bottom: 1em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .nocards li{
|
||||||
|
margin-left: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card{
|
||||||
|
display: inline-block;
|
||||||
|
margin: 0.5em;
|
||||||
|
width: 16em;
|
||||||
|
height: 24em;
|
||||||
|
box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
|
||||||
|
border-radius: 0.2em;
|
||||||
|
position: relative;
|
||||||
|
|
||||||
|
background-size: auto 100%;
|
||||||
|
background-position: center;
|
||||||
|
overflow: hidden;
|
||||||
|
cursor: pointer;
|
||||||
|
|
||||||
|
background-image: url('./file=html/card-no-preview.png')
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card:hover{
|
||||||
|
box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card .actions .additional{
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card .actions{
|
||||||
|
position: absolute;
|
||||||
|
bottom: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
padding: 0.5em;
|
||||||
|
color: white;
|
||||||
|
background: rgba(0,0,0,0.5);
|
||||||
|
box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);
|
||||||
|
text-shadow: 0 0 0.2em black;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card .actions:hover{
|
||||||
|
box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card .actions .name{
|
||||||
|
font-size: 1.7em;
|
||||||
|
font-weight: bold;
|
||||||
|
line-break: anywhere;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card .actions:hover .additional{
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card ul{
|
||||||
|
margin: 0.25em 0 0.75em 0.25em;
|
||||||
|
cursor: unset;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card ul a{
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card ul a:hover{
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
|
||||||
|
26
webui.py
26
webui.py
@ -9,16 +9,18 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
from modules import import_hook, errors
|
from modules import import_hook, errors, extra_networks
|
||||||
|
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||||
|
|
||||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
|
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
|
||||||
import modules.codeformer_model as codeformer
|
import modules.codeformer_model as codeformer
|
||||||
import modules.extras
|
import modules.extras
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
@ -84,10 +86,17 @@ def initialize():
|
|||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
|
|
||||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
|
||||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||||
|
|
||||||
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
|
ui_extra_networks.intialize()
|
||||||
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
|
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||||
|
|
||||||
|
extra_networks.initialize()
|
||||||
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
|
|
||||||
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -209,6 +218,15 @@ def webui():
|
|||||||
|
|
||||||
modules.sd_models.list_models()
|
modules.sd_models.list_models()
|
||||||
|
|
||||||
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
|
ui_extra_networks.intialize()
|
||||||
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
|
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||||
|
|
||||||
|
extra_networks.initialize()
|
||||||
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if cmd_opts.nowebui:
|
if cmd_opts.nowebui:
|
||||||
|
Loading…
Reference in New Issue
Block a user