mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui
synced 2025-01-07 07:46:58 +08:00
support scheduler selection in hires fix
This commit is contained in:
parent
755d2cb2e5
commit
9aa9e980a9
@ -314,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Hires sampler" not in res:
|
||||
res["Hires sampler"] = "Use same sampler"
|
||||
|
||||
if "Hires schedule type" not in res:
|
||||
res["Hires schedule type"] = "Use same scheduler"
|
||||
|
||||
if "Hires checkpoint" not in res:
|
||||
res["Hires checkpoint"] = "Use same checkpoint"
|
||||
|
||||
|
@ -1115,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
hr_resize_y: int = 0
|
||||
hr_checkpoint_name: str = None
|
||||
hr_sampler_name: str = None
|
||||
hr_scheduler: str = None
|
||||
hr_prompt: str = ''
|
||||
hr_negative_prompt: str = ''
|
||||
force_task_id: str = None
|
||||
@ -1203,6 +1204,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
||||
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
||||
|
||||
self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py
|
||||
|
||||
if self.hr_scheduler is None:
|
||||
self.hr_scheduler = self.scheduler
|
||||
|
||||
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||
if self.enable_hr and self.latent_scale_mode is None:
|
||||
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
||||
|
@ -1,44 +1,10 @@
|
||||
import gradio as gr
|
||||
import functools
|
||||
|
||||
from modules import scripts, sd_samplers, sd_schedulers, shared
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.ui_components import FormRow, FormGroup
|
||||
|
||||
|
||||
def get_sampler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
||||
|
||||
|
||||
def get_scheduler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_sampler_and_scheduler(sampler_name, scheduler_name):
|
||||
default_sampler = sd_samplers.samplers[0]
|
||||
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
||||
|
||||
name = sampler_name or default_sampler.name
|
||||
|
||||
for scheduler in sd_schedulers.schedulers:
|
||||
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
||||
|
||||
for name_option in name_options:
|
||||
if name.endswith(" " + name_option):
|
||||
found_scheduler = scheduler
|
||||
name = name[0:-(len(name_option) + 1)]
|
||||
break
|
||||
|
||||
sampler = sd_samplers.all_samplers_map.get(name, default_sampler)
|
||||
|
||||
# revert back to Automatic if it's the default scheduler for the selected sampler
|
||||
if sampler.options.get('scheduler', None) == found_scheduler.name:
|
||||
found_scheduler = sd_schedulers.schedulers[0]
|
||||
|
||||
return sampler.name, found_scheduler.label
|
||||
|
||||
|
||||
class ScriptSampler(scripts.ScriptBuiltinUI):
|
||||
section = "sampler"
|
||||
|
||||
@ -67,8 +33,8 @@ class ScriptSampler(scripts.ScriptBuiltinUI):
|
||||
|
||||
self.infotext_fields = [
|
||||
PasteField(self.steps, "Steps", api="steps"),
|
||||
PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"),
|
||||
PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"),
|
||||
PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
|
||||
PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
|
||||
]
|
||||
|
||||
return self.steps, self.sampler_name, self.scheduler
|
||||
|
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common
|
||||
import functools
|
||||
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
|
||||
@ -64,4 +66,60 @@ def visible_samplers():
|
||||
return [x for x in samplers if x.name not in samplers_hidden]
|
||||
|
||||
|
||||
def get_sampler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
||||
|
||||
|
||||
def get_scheduler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
||||
|
||||
|
||||
def get_hr_sampler_and_scheduler(d: dict):
|
||||
hr_sampler = d.get("Hires sampler", "Use same sampler")
|
||||
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
|
||||
|
||||
hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
|
||||
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
|
||||
|
||||
sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
|
||||
|
||||
sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
|
||||
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
|
||||
|
||||
return sampler, scheduler
|
||||
|
||||
|
||||
def get_hr_sampler_from_infotext(d: dict):
|
||||
return get_hr_sampler_and_scheduler(d)[0]
|
||||
|
||||
|
||||
def get_hr_scheduler_from_infotext(d: dict):
|
||||
return get_hr_sampler_and_scheduler(d)[1]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_sampler_and_scheduler(sampler_name, scheduler_name):
|
||||
default_sampler = samplers[0]
|
||||
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
||||
|
||||
name = sampler_name or default_sampler.name
|
||||
|
||||
for scheduler in sd_schedulers.schedulers:
|
||||
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
||||
|
||||
for name_option in name_options:
|
||||
if name.endswith(" " + name_option):
|
||||
found_scheduler = scheduler
|
||||
name = name[0:-(len(name_option) + 1)]
|
||||
break
|
||||
|
||||
sampler = all_samplers_map.get(name, default_sampler)
|
||||
|
||||
# revert back to Automatic if it's the default scheduler for the selected sampler
|
||||
if sampler.options.get('scheduler', None) == found_scheduler.name:
|
||||
found_scheduler = sd_schedulers.schedulers[0]
|
||||
|
||||
return sampler.name, found_scheduler.label
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
@ -79,7 +79,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
scheduler_name = p.scheduler or 'Automatic'
|
||||
scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
|
||||
if scheduler_name == 'Automatic':
|
||||
scheduler_name = self.config.options.get('scheduler', None)
|
||||
|
||||
@ -95,8 +95,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
else:
|
||||
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
|
||||
|
||||
if scheduler.label != 'Automatic':
|
||||
if scheduler.label != 'Automatic' and not p.is_hr_pass:
|
||||
p.extra_generation_params["Schedule type"] = scheduler.label
|
||||
elif scheduler.label != p.extra_generation_params.get("Schedule type"):
|
||||
p.extra_generation_params["Hires schedule type"] = scheduler.label
|
||||
|
||||
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
|
||||
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
||||
|
@ -11,7 +11,7 @@ from PIL import Image
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
||||
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
if force_enable_hr:
|
||||
@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
|
||||
hr_resize_y=hr_resize_y,
|
||||
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
||||
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
|
||||
hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
|
||||
hr_prompt=hr_prompt,
|
||||
hr_negative_prompt=hr_negative_prompt,
|
||||
override_settings=override_settings,
|
||||
|
@ -322,10 +322,11 @@ def create_ui():
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||
|
||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||
hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||
|
||||
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||
hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||
hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||
with gr.Column(scale=80):
|
||||
@ -394,6 +395,7 @@ def create_ui():
|
||||
hr_resize_y,
|
||||
hr_checkpoint_name,
|
||||
hr_sampler_name,
|
||||
hr_scheduler,
|
||||
hr_prompt,
|
||||
hr_negative_prompt,
|
||||
override_settings,
|
||||
@ -456,8 +458,9 @@ def create_ui():
|
||||
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
|
||||
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
|
||||
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
|
||||
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
|
||||
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
||||
PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
|
||||
PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
|
||||
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
|
||||
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
|
||||
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
|
||||
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
||||
|
Loading…
Reference in New Issue
Block a user