refa: optimize request params for nai4 (#269)

This commit is contained in:
Guicheng Liu 2024-12-25 22:21:24 +08:00 committed by GitHub
parent d84328c1a6
commit 4d09b73fcd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 5 deletions

View File

@ -226,6 +226,7 @@ interface ParamConfig {
smea?: boolean smea?: boolean
smeaDyn?: boolean smeaDyn?: boolean
scheduler?: string scheduler?: string
rescale?: Computed<number>
decrisper?: boolean decrisper?: boolean
upscaler?: string upscaler?: string
restoreFaces?: boolean restoreFaces?: boolean
@ -236,6 +237,7 @@ interface ParamConfig {
imageSteps?: Computed<number> imageSteps?: Computed<number>
maxSteps?: Computed<number> maxSteps?: Computed<number>
strength?: Computed<number> strength?: Computed<number>
noise?: Computed<number>
resolution?: Computed<Orient | Size> resolution?: Computed<Orient | Size>
maxResolution?: Computed<number> maxResolution?: Computed<number>
} }
@ -402,8 +404,9 @@ export const Config = Schema.intersect([
}), }),
Schema.object({ Schema.object({
model: Schema.const('nai-v4-curated-preview'), model: Schema.const('nai-v4-curated-preview'),
sampler: sampler.createSchema(sampler.nai4), sampler: sampler.createSchema(sampler.nai4).default('k_euler_a'),
scheduler: Schema.union(scheduler.nai4).description('默认的调度器。').default('karras'), scheduler: Schema.union(scheduler.nai4).description('默认的调度器。').default('karras'),
rescale: Schema.computed(Schema.number(), options).min(0).max(1).description('输入服从度调整规模。').default(0),
}), }),
Schema.object({ sampler: sampler.createSchema(sampler.nai) }), Schema.object({ sampler: sampler.createSchema(sampler.nai) }),
]), ]),
@ -417,6 +420,7 @@ export const Config = Schema.intersect([
imageSteps: Schema.computed(Schema.natural(), options).description('以图生图时默认的迭代步数。').default(50), imageSteps: Schema.computed(Schema.natural(), options).description('以图生图时默认的迭代步数。').default(50),
maxSteps: Schema.computed(Schema.natural(), options).description('允许的最大迭代步数。').default(64), maxSteps: Schema.computed(Schema.natural(), options).description('允许的最大迭代步数。').default(64),
strength: Schema.computed(Schema.number(), options).min(0).max(1).description('默认的重绘强度。').default(0.7), strength: Schema.computed(Schema.number(), options).min(0).max(1).description('默认的重绘强度。').default(0.7),
noise: Schema.computed(Schema.number(), options).min(0).max(1).description('默认的重绘添加噪声强度。').default(0.2),
resolution: Schema.computed(Schema.union([ resolution: Schema.computed(Schema.union([
Schema.const('portrait').description('肖像 (832x2326)'), Schema.const('portrait').description('肖像 (832x2326)'),
Schema.const('landscape').description('风景 (1216x832)'), Schema.const('landscape').description('风景 (1216x832)'),

View File

@ -261,7 +261,7 @@ export function apply(ctx: Context, config: Config) {
Object.assign(parameters, { Object.assign(parameters, {
height: options.resolution.height, height: options.resolution.height,
width: options.resolution.width, width: options.resolution.width,
noise: options.noise ?? 0.2, noise: options.noise ?? session.resolve(config.noise),
strength: options.strength ?? session.resolve(config.strength), strength: options.strength ?? session.resolve(config.strength),
}) })
} }
@ -334,7 +334,9 @@ export function apply(ctx: Context, config: Config) {
} }
parameters.dynamic_thresholding = options.decrisper ?? config.decrisper parameters.dynamic_thresholding = options.decrisper ?? config.decrisper
if (model === 'nai-diffusion-3' || model === 'nai-diffusion-4-curated-preview') { if (model === 'nai-diffusion-3' || model === 'nai-diffusion-4-curated-preview') {
parameters.params_version = 3
parameters.legacy = false parameters.legacy = false
parameters.legacy_v3_extend = false
parameters.noise_schedule = options.scheduler ?? config.scheduler parameters.noise_schedule = options.scheduler ?? config.scheduler
// Max scale for nai-v3 is 10, but not 20. // Max scale for nai-v3 is 10, but not 20.
// If the given value is greater than 10, // If the given value is greater than 10,
@ -343,7 +345,6 @@ export function apply(ctx: Context, config: Config) {
parameters.scale = parameters.scale / 2 parameters.scale = parameters.scale / 2
} }
if (model === 'nai-diffusion-3') { if (model === 'nai-diffusion-3') {
parameters.legacy_v3_extend = false
parameters.sm_dyn = options.smeaDyn ?? config.smeaDyn parameters.sm_dyn = options.smeaDyn ?? config.smeaDyn
parameters.sm = (options.smea ?? config.smea) || parameters.sm_dyn parameters.sm = (options.smea ?? config.smea) || parameters.sm_dyn
if (['k_euler_ancestral', 'k_dpmpp_2s_ancestral'].includes(parameters.sampler) if (['k_euler_ancestral', 'k_dpmpp_2s_ancestral'].includes(parameters.sampler)
@ -357,8 +358,17 @@ export function apply(ctx: Context, config: Config) {
} }
} }
if (model === 'nai-diffusion-4-curated-preview') { if (model === 'nai-diffusion-4-curated-preview') {
parameters.use_coords = false // unknown parameters.add_original_image = true // unknown
parameters.cfg_rescale = session.resolve(config.rescale)
parameters.characterPrompts = [] satisfies NovelAI.V4CharacterPrompt[] parameters.characterPrompts = [] satisfies NovelAI.V4CharacterPrompt[]
parameters.controlnet_strength = 1 // unknown
parameters.deliberate_euler_ancestral_bug = false // unknown
parameters.prefer_brownian = true // unknown
parameters.reference_image_multiple = [] // unknown
parameters.reference_information_extracted_multiple = [] // unknown
parameters.reference_strength_multiple = [] // unknown
parameters.skip_cfg_above_sigma = null // unknown
parameters.use_coords = false // unknown
parameters.v4_prompt = { parameters.v4_prompt = {
caption: { caption: {
base_caption: prompt, base_caption: prompt,
@ -479,7 +489,7 @@ export function apply(ctx: Context, config: Config) {
prompt[nodeId].inputs.steps = parameters.steps prompt[nodeId].inputs.steps = parameters.steps
prompt[nodeId].inputs.cfg = parameters.scale prompt[nodeId].inputs.cfg = parameters.scale
prompt[nodeId].inputs.sampler_name = options.sampler prompt[nodeId].inputs.sampler_name = options.sampler
prompt[nodeId].inputs.denoise = options.strength ?? config.strength prompt[nodeId].inputs.denoise = options.strength ?? session.resolve(config.strength)
prompt[nodeId].inputs.scheduler = options.scheduler ?? config.scheduler prompt[nodeId].inputs.scheduler = options.scheduler ?? config.scheduler
const positiveNodeId = prompt[nodeId].inputs.positive[0] const positiveNodeId = prompt[nodeId].inputs.positive[0]
const negativeeNodeId = prompt[nodeId].inputs.negative[0] const negativeeNodeId = prompt[nodeId].inputs.negative[0]