diff --git a/modules/processing.py b/modules/processing.py index 92c3582cc..0c747601f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from skimage import exposure from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -457,6 +457,14 @@ class StableDiffusionProcessing: opts.emphasis, ) + def apply_generation_params_states(self, generation_params_states): + """add and apply generation_params_states to self.extra_generation_params""" + for key, value in generation_params_states.items(): + if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList): + self.extra_generation_params[key] = current_value + value + else: + self.extra_generation_params[key] = value + def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) @@ -480,6 +488,10 @@ class StableDiffusionProcessing: for cache in caches: if cache[0] is not None and cached_params == cache[0]: + if len(cache) == 3: + generation_params_states, cached_cached_params = cache[2] + if cached_params == cached_cached_params: + self.apply_generation_params_states(generation_params_states) return cache[1] cache = caches[0] @@ -487,6 +499,13 @@ class StableDiffusionProcessing: with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) + generation_params_states = model_hijack.extract_generation_params_states() + self.apply_generation_params_states(generation_params_states) + if len(cache) == 2: + cache.append((generation_params_states, cached_params)) + else: + cache[2] = (generation_params_states, cached_params) + cache[0] = cached_params return cache[1] @@ -502,6 +521,8 @@ class StableDiffusionProcessing: self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data) + self.extra_generation_params.update(model_hijack.extra_generation_params) + def get_conds(self): return self.c, self.uc @@ -801,10 +822,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter for key, value in generation_params.items(): try: - if isinstance(value, list): - generation_params[key] = value[index] - elif callable(value): + if callable(value): generation_params[key] = value(**locals()) + elif isinstance(value, list): + generation_params[key] = value[index] except Exception: errors.report(f'Error creating infotext for key "{key}"', exc_info=True) generation_params[key] = None @@ -965,8 +986,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds() - p.extra_generation_params.update(model_hijack.extra_generation_params) - # params.txt should be saved after scripts.process_batch, since the # infotext could be modified by that callback # Example: a wildcard processed by process_batch sets an extra model @@ -1513,6 +1532,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps) + self.extra_generation_params.update(model_hijack.extra_generation_params) + def setup_conds(self): if self.is_hr_pass: # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0de830541..4ac22ec53 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,7 +2,7 @@ import torch from torch.nn.functional import silu from types import MethodType -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 @@ -321,6 +321,14 @@ class StableDiffusionModelHijack: self.comments = [] self.extra_generation_params = {} + def extract_generation_params_states(self): + """Extracts GenerationParametersList so that they can be cached and restored later""" + states = {} + for key in list(self.extra_generation_params): + if isinstance(self.extra_generation_params[key], util.GenerationParametersList): + states[key] = self.extra_generation_params.pop(key) + return states + def get_prompt_lengths(self, text): if self.clip is None: return "-", "-" diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index a479148fc..62c632f82 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -3,7 +3,7 @@ from collections import namedtuple import torch -from modules import prompt_parser, devices, sd_hijack, sd_emphasis +from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util from modules.shared import opts @@ -27,6 +27,30 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" +class EmphasisMode(util.GenerationParametersList): + def __init__(self, emphasis_mode:str = None): + super().__init__() + self.emphasis_mode = emphasis_mode + + def __call__(self, *args, **kwargs): + return self.emphasis_mode + + def __add__(self, other): + if isinstance(other, EmphasisMode): + return self if self.emphasis_mode else other + elif isinstance(other, str): + return self.__str__() + other + return NotImplemented + + def __radd__(self, other): + if isinstance(other, str): + return other + self.__str__() + return NotImplemented + + def __str__(self): + return self.emphasis_mode if self.emphasis_mode else '' + + class TextConditionalModel(torch.nn.Module): def __init__(self): super().__init__() @@ -238,12 +262,10 @@ class TextConditionalModel(torch.nn.Module): hashes.append(f"{name}: {shorthash}") if hashes: - if self.hijack.extra_generation_params.get("TI hashes"): - hashes.append(self.hijack.extra_generation_params.get("TI hashes")) - self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) + self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes) - if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": - self.hijack.extra_generation_params["Emphasis"] = opts.emphasis + if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x): + self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis) if self.return_pooled: return torch.hstack(zs), zs[0].pooled diff --git a/modules/util.py b/modules/util.py index baeba2fa2..1aef93cfe 100644 --- a/modules/util.py +++ b/modules/util.py @@ -288,3 +288,41 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower()) + + +class GenerationParametersList(list): + """A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params + due to StableDiffusionProcessing.get_conds_with_caching + extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used + + When an extra_generation_params is set in StableDiffusionModelHijack using this object, + the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states + the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching + and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states + + Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis' + + Depending on the use case the methods can be overwritten. + In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext. + When called by create_infotext it will access to the locals() of the caller, + if return str, the value will be written to infotext, if return None will be ignored. + """ + + def __call__(self, *args, **kwargs): + return ', '.join(sorted(set(self), key=natural_sort_key)) + + def __add__(self, other): + if isinstance(other, GenerationParametersList): + return self.__class__([*self, *other]) + elif isinstance(other, str): + return self.__str__() + other + return NotImplemented + + def __radd__(self, other): + if isinstance(other, str): + return other + self.__str__() + return NotImplemented + + def __str__(self): + return self.__call__() +