fix missing infotext cased by conda cache

some generation params such as TI hashes or Emphasis is added in sd_hijack / sd_hijack_clip
if conda are fetche from cache sd_hijack_clip will not be executed and it won't have a chance to to add generation params

the generation params will also be missing if in non low-vram mode because the hijack.extra_generation_params was never read after calculate_hr_conds
This commit is contained in:
w-e-w 2024-11-23 17:31:01 +09:00
parent 023454b49e
commit 025080218f
4 changed files with 102 additions and 13 deletions

View File

@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@ -457,6 +457,14 @@ class StableDiffusionProcessing:
opts.emphasis,
)
def apply_generation_params_states(self, generation_params_states):
"""add and apply generation_params_states to self.extra_generation_params"""
for key, value in generation_params_states.items():
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
self.extra_generation_params[key] = current_value + value
else:
self.extra_generation_params[key] = value
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@ -480,6 +488,10 @@ class StableDiffusionProcessing:
for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
if len(cache) == 3:
generation_params_states, cached_cached_params = cache[2]
if cached_params == cached_cached_params:
self.apply_generation_params_states(generation_params_states)
return cache[1]
cache = caches[0]
@ -487,6 +499,13 @@ class StableDiffusionProcessing:
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
generation_params_states = model_hijack.extract_generation_params_states()
self.apply_generation_params_states(generation_params_states)
if len(cache) == 2:
cache.append((generation_params_states, cached_params))
else:
cache[2] = (generation_params_states, cached_params)
cache[0] = cached_params
return cache[1]
@ -502,6 +521,8 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
self.extra_generation_params.update(model_hijack.extra_generation_params)
def get_conds(self):
return self.c, self.uc
@ -801,10 +822,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
for key, value in generation_params.items():
try:
if isinstance(value, list):
generation_params[key] = value[index]
elif callable(value):
if callable(value):
generation_params[key] = value(**locals())
elif isinstance(value, list):
generation_params[key] = value[index]
except Exception:
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
generation_params[key] = None
@ -965,8 +986,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()
p.extra_generation_params.update(model_hijack.extra_generation_params)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model
@ -1513,6 +1532,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
self.extra_generation_params.update(model_hijack.extra_generation_params)
def setup_conds(self):
if self.is_hr_pass:
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model

View File

@ -2,7 +2,7 @@ import torch
from torch.nn.functional import silu
from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
@ -321,6 +321,14 @@ class StableDiffusionModelHijack:
self.comments = []
self.extra_generation_params = {}
def extract_generation_params_states(self):
"""Extracts GenerationParametersList so that they can be cached and restored later"""
states = {}
for key in list(self.extra_generation_params):
if isinstance(self.extra_generation_params[key], util.GenerationParametersList):
states[key] = self.extra_generation_params.pop(key)
return states
def get_prompt_lengths(self, text):
if self.clip is None:
return "-", "-"

View File

@ -3,7 +3,7 @@ from collections import namedtuple
import torch
from modules import prompt_parser, devices, sd_hijack, sd_emphasis
from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
from modules.shared import opts
@ -27,6 +27,30 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class EmphasisMode(util.GenerationParametersList):
def __init__(self, emphasis_mode:str = None):
super().__init__()
self.emphasis_mode = emphasis_mode
def __call__(self, *args, **kwargs):
return self.emphasis_mode
def __add__(self, other):
if isinstance(other, EmphasisMode):
return self if self.emphasis_mode else other
elif isinstance(other, str):
return self.__str__() + other
return NotImplemented
def __radd__(self, other):
if isinstance(other, str):
return other + self.__str__()
return NotImplemented
def __str__(self):
return self.emphasis_mode if self.emphasis_mode else ''
class TextConditionalModel(torch.nn.Module):
def __init__(self):
super().__init__()
@ -238,12 +262,10 @@ class TextConditionalModel(torch.nn.Module):
hashes.append(f"{name}: {shorthash}")
if hashes:
if self.hijack.extra_generation_params.get("TI hashes"):
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes)
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis)
if self.return_pooled:
return torch.hstack(zs), zs[0].pooled

View File

@ -288,3 +288,41 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
class GenerationParametersList(list):
"""A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params
due to StableDiffusionProcessing.get_conds_with_caching
extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used
When an extra_generation_params is set in StableDiffusionModelHijack using this object,
the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states
the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching
and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states
Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis'
Depending on the use case the methods can be overwritten.
In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext.
When called by create_infotext it will access to the locals() of the caller,
if return str, the value will be written to infotext, if return None will be ignored.
"""
def __call__(self, *args, **kwargs):
return ', '.join(sorted(set(self), key=natural_sort_key))
def __add__(self, other):
if isinstance(other, GenerationParametersList):
return self.__class__([*self, *other])
elif isinstance(other, str):
return self.__str__() + other
return NotImplemented
def __radd__(self, other):
if isinstance(other, str):
return other + self.__str__()
return NotImplemented
def __str__(self):
return self.__call__()