clear GenerationParametersList before batch

clears any generation parameters that have the attribute to_be_clear_before_batch = True, to prevent buildup of some parameters across batches
w-e-w 2024-11-24 20:07:00 +09:00
parent 025080218f
commit ac8c05398b
2 changed files with 18 additions and 3 deletions
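
To illustrate the behaviour this commit adds, here is a minimal, self-contained Python sketch. The names mirror the diff, but the bodies are simplified stand-ins rather than the actual webui code: the "Extra hashes" key and its values are made up, and the merge branch of apply_generation_params_list (elided between the first two hunks below) is assumed to extend the existing list in place. Cached conds re-apply their recorded generation_params_states on every batch, so without the new clearing step, list-type entries recorded in one batch would carry over into the next batch's infotext.

# Simplified stand-ins that mirror the names in this commit; not the actual webui code.
class GenerationParametersList(list):
    to_be_clear_before_batch = True        # the real class exposes this via an __init__ flag and a property


def apply_generation_params_list(extra_generation_params, generation_params_states):
    """Re-apply recorded states, merging list-type values instead of overwriting them."""
    for key, value in generation_params_states.items():
        current_value = extra_generation_params.get(key)
        if isinstance(current_value, GenerationParametersList):
            current_value += value         # assumed behaviour of the branch elided between the first two hunks
        else:
            extra_generation_params[key] = value


def clear_marked_generation_params(extra_generation_params):
    """Drop every entry whose value is flagged to be cleared before a batch."""
    for key, value in list(extra_generation_params.items()):
        if getattr(value, 'to_be_clear_before_batch', False):
            extra_generation_params.pop(key)


extra_generation_params = {}
per_batch_states = [                                             # "Extra hashes" is a made-up key
    {"Extra hashes": GenerationParametersList(["aaaa"])},
    {"Extra hashes": GenerationParametersList(["bbbb"])},
]

for n, states in enumerate(per_batch_states):
    clear_marked_generation_params(extra_generation_params)     # the new per-batch step
    apply_generation_params_list(extra_generation_params, states)
    print(n, list(extra_generation_params["Extra hashes"]))
    # with clearing: batch 0 -> ['aaaa'], batch 1 -> ['bbbb']
    # without it, batch 1 would show ['aaaa', 'bbbb'] -- the buildup this commit prevents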

View File

@@ -457,7 +457,7 @@ class StableDiffusionProcessing:
             opts.emphasis,
         )
 
-    def apply_generation_params_states(self, generation_params_states):
+    def apply_generation_params_list(self, generation_params_states):
         """add and apply generation_params_states to self.extra_generation_params"""
         for key, value in generation_params_states.items():
             if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
@@ -465,6 +465,12 @@ class StableDiffusionProcessing:
             else:
                 self.extra_generation_params[key] = value
 
+    def clear_marked_generation_params(self):
+        """clears any generation parameters that have the attribute to_be_clear_before_batch = True"""
+        for key, value in list(self.extra_generation_params.items()):
+            if getattr(value, 'to_be_clear_before_batch', False):
+                self.extra_generation_params.pop(key)
+
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
         """
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -491,7 +497,7 @@ class StableDiffusionProcessing:
             if len(cache) == 3:
                 generation_params_states, cached_cached_params = cache[2]
                 if cached_params == cached_cached_params:
-                    self.apply_generation_params_states(generation_params_states)
+                    self.apply_generation_params_list(generation_params_states)
                     return cache[1]
 
         cache = caches[0]
@@ -500,7 +506,7 @@ class StableDiffusionProcessing:
             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
             generation_params_states = model_hijack.extract_generation_params_states()
-            self.apply_generation_params_states(generation_params_states)
+            self.apply_generation_params_list(generation_params_states)
 
         if len(cache) == 2:
             cache.append((generation_params_states, cached_params))
         else:
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted or state.stopping_generation:
                 break
 
+            p.clear_marked_generation_params()  # clean up generation params that are tagged to be cleared before each batch
             sd_models.reload_model_weights()  # model can be changed for example by refiner
 
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]

View File

@@ -308,9 +308,17 @@ class GenerationParametersList(list):
     if return str, the value will be written to infotext, if return None will be ignored.
     """
 
+    def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._to_be_clear_before_batch = to_be_clear_before_batch
+
     def __call__(self, *args, **kwargs):
         return ', '.join(sorted(set(self), key=natural_sort_key))
 
+    @property
+    def to_be_clear_before_batch(self):
+        return self._to_be_clear_before_batch
+
     def __add__(self, other):
         if isinstance(other, GenerationParametersList):
             return self.__class__([*self, *other])
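
A quick usage sketch against the class as shown in this diff (the values are hypothetical): the flag defaults to True, is exposed read-only through the new property, and __add__ still returns an instance of the same class, so merged lists remain clearable before the next batch.

a = GenerationParametersList(["alpha"])                                # default: to_be_clear_before_batch=True
b = GenerationParametersList(["beta"], to_be_clear_before_batch=False)

print(a.to_be_clear_before_batch)   # True
print(b.to_be_clear_before_batch)   # False

merged = a + b                      # __add__ returns a GenerationParametersList, not a plain list
print(type(merged).__name__)        # GenerationParametersList
print(merged())                     # 'alpha, beta' -- __call__ joins the sorted, de-duplicated values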