mirror of
https://github.com/koishijs/novelai-bot
synced 2025-01-09 20:11:56 +08:00
feat: adapt sampler with type
This commit is contained in:
parent
0074d4e2b4
commit
1c7ff267de
156
src/index.ts
156
src/index.ts
@ -1,6 +1,6 @@
|
||||
import { Context, Dict, Logger, Quester, Schema, segment, Session, Time, trimSlash } from 'koishi'
|
||||
import { StableDiffusionWebUI } from './types'
|
||||
import { download, getImageSize, login, NetworkError, project, resizeInput, samplersMapN2S } from './utils'
|
||||
import { download, getImageSize, login, NetworkError, project, resizeInput } from './utils'
|
||||
import {} from '@koishijs/plugin-help'
|
||||
|
||||
export const reactive = true
|
||||
@ -34,11 +34,30 @@ const badAnatomy = [
|
||||
|
||||
type Model = keyof typeof modelMap
|
||||
type Orient = keyof typeof orientMap
|
||||
type Sampler = typeof samplers[number]
|
||||
|
||||
const models = Object.keys(modelMap) as Model[]
|
||||
const orients = Object.keys(orientMap) as Orient[]
|
||||
const samplers = ['k_euler_ancestral', 'k_euler', 'k_lms', 'plms', 'ddim'] as const
|
||||
const naiSamplers = ['k_euler_ancestral', 'k_euler', 'k_lms', 'plms', 'ddim']
|
||||
const sdSamplers = {
|
||||
'k_euler_a': 'Euler a',
|
||||
'k_euler': 'Euler',
|
||||
'k_lms': 'LMS',
|
||||
'k_heun': 'Heun',
|
||||
'k_dpm_2': 'DPM2',
|
||||
'k_dpm_2_a': 'DPM2 a',
|
||||
'k_dpm_fast': 'DPM fast',
|
||||
'k_dpm_ad': 'DPM adaptive',
|
||||
'k_lms_ka': 'LMS Karras',
|
||||
'k_dpm_2_ka': 'DPM2 Karras',
|
||||
'k_dpm_2_a_ka': 'DPM2 a Karras',
|
||||
'ddim': 'DDIM',
|
||||
'plms': 'PLMS',
|
||||
}
|
||||
|
||||
function toNAISampler(sampler: string): string {
|
||||
if (naiSamplers.includes(sampler)) return sampler
|
||||
return 'k_euler_ancestral'
|
||||
}
|
||||
|
||||
export interface Config {
|
||||
type: 'token' | 'login' | 'naifu' | 'sd-webui'
|
||||
@ -47,7 +66,7 @@ export interface Config {
|
||||
password: string
|
||||
model?: Model
|
||||
orient?: Orient
|
||||
sampler?: Sampler
|
||||
sampler?: string
|
||||
anatomy?: boolean
|
||||
output?: 'minimal' | 'default' | 'verbose'
|
||||
allowAnlas?: boolean | number
|
||||
@ -71,52 +90,70 @@ export const Config = Schema.intersect([
|
||||
Schema.const('sd-webui' as const).description('sd-webui'),
|
||||
] as const).description('登录方式'),
|
||||
}).description('登录设置'),
|
||||
|
||||
Schema.union([
|
||||
Schema.object({
|
||||
type: Schema.const('token' as const),
|
||||
token: Schema.string().description('授权令牌。').role('secret').required(),
|
||||
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
|
||||
headers: Schema.dict(String).description('要附加的额外请求头。').default({
|
||||
'referer': 'https://novelai.net/',
|
||||
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36',
|
||||
Schema.intersect([
|
||||
Schema.union([
|
||||
Schema.object({
|
||||
type: Schema.const('token'),
|
||||
token: Schema.string().description('授权令牌。').role('secret').required(),
|
||||
}),
|
||||
Schema.object({
|
||||
type: Schema.const('login'),
|
||||
email: Schema.string().description('用户名。').required(),
|
||||
password: Schema.string().description('密码。').role('secret').required(),
|
||||
}),
|
||||
]),
|
||||
Schema.object({
|
||||
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
|
||||
headers: Schema.dict(String).description('要附加的额外请求头。').default({
|
||||
'referer': 'https://novelai.net/',
|
||||
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36',
|
||||
}),
|
||||
allowAnlas: Schema.union([
|
||||
Schema.const(true).description('允许'),
|
||||
Schema.const(false).description('禁止'),
|
||||
Schema.natural().description('权限等级').default(1),
|
||||
]).default(true).description('是否允许使用点数。禁用后部分功能 (图片增强和手动设置某些参数) 将无法使用。'),
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
Schema.object({
|
||||
type: Schema.const('login' as const),
|
||||
email: Schema.string().description('用户名。').required(),
|
||||
password: Schema.string().description('密码。').role('secret').required(),
|
||||
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
|
||||
headers: Schema.dict(String).description('要附加的额外请求头。').default({
|
||||
'referer': 'https://novelai.net/',
|
||||
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36',
|
||||
}),
|
||||
}),
|
||||
Schema.object({
|
||||
type: Schema.const('naifu' as const),
|
||||
type: Schema.const('naifu'),
|
||||
token: Schema.string().description('授权令牌。').role('secret'),
|
||||
endpoint: Schema.string().description('API 服务器地址。').required(),
|
||||
headers: Schema.dict(String).description('要附加的额外请求头。'),
|
||||
}),
|
||||
Schema.object({
|
||||
type: Schema.const('sd-webui' as const),
|
||||
type: Schema.const('sd-webui'),
|
||||
endpoint: Schema.string().description('API 服务器地址。').required(),
|
||||
headers: Schema.dict(String).description('要附加的额外请求头。'),
|
||||
}),
|
||||
]),
|
||||
|
||||
Schema.union([
|
||||
Schema.object({
|
||||
type: Schema.const('sd-webui'),
|
||||
sampler: Schema.union(Object.entries(sdSamplers).map(([key, value]) => {
|
||||
return Schema.const(key).description(value)
|
||||
})).description('默认的采样器。').default('k_euler_a'),
|
||||
}).description('功能设置'),
|
||||
Schema.object({
|
||||
type: Schema.const('naifu'),
|
||||
sampler: Schema.union(naiSamplers).description('默认的采样器。').default('k_euler_ancestral'),
|
||||
}).description('功能设置'),
|
||||
Schema.object({
|
||||
model: Schema.union(models).description('默认的生成模型。').default('nai'),
|
||||
sampler: Schema.union(naiSamplers).description('默认的采样器。').default('k_euler_ancestral'),
|
||||
}).description('功能设置'),
|
||||
] as const),
|
||||
|
||||
Schema.object({
|
||||
model: Schema.union(models).description('默认的生成模型。').default('nai'),
|
||||
orient: Schema.union(orients).description('默认的图片方向。').default('portrait'),
|
||||
sampler: Schema.union(samplers).description('默认的采样器。').default('k_euler_ancestral'),
|
||||
anatomy: Schema.boolean().default(true).description('是否过滤不合理构图。'),
|
||||
output: Schema.union([
|
||||
Schema.const('minimal' as const).description('只发送图片'),
|
||||
Schema.const('default' as const).description('发送图片和关键信息'),
|
||||
Schema.const('verbose' as const).description('发送全部信息'),
|
||||
Schema.const('minimal').description('只发送图片'),
|
||||
Schema.const('default').description('发送图片和关键信息'),
|
||||
Schema.const('verbose').description('发送全部信息'),
|
||||
]).description('输出方式。').default('default'),
|
||||
allowAnlas: Schema.union([
|
||||
Schema.const(true).description('允许'),
|
||||
Schema.const(false).description('禁止'),
|
||||
Schema.natural().description('权限等级').default(1),
|
||||
]).default(true).description('是否允许使用点数。禁用后部分功能 (图片增强和手动设置某些参数) 将无法使用。'),
|
||||
basePrompt: Schema.string().role('textarea').description('默认附加的标签。').default('masterpiece, best quality'),
|
||||
negativePrompt: Schema.string().role('textarea').description('默认附加的反向标签。').default([lowQuality, badAnatomy].join(', ')),
|
||||
forbidden: Schema.string().role('textarea').description('违禁词列表。含有违禁词的请求将被拒绝。').default(''),
|
||||
@ -124,8 +161,8 @@ export const Config = Schema.intersect([
|
||||
requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute),
|
||||
recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0),
|
||||
maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0),
|
||||
}).description('功能设置'),
|
||||
] as const) as Schema<Config>
|
||||
}),
|
||||
]) as Schema<Config>
|
||||
|
||||
function handleError(session: Session, err: Error) {
|
||||
if (Quester.isAxiosError(err)) {
|
||||
@ -196,7 +233,7 @@ export function apply(ctx: Context, config: Config) {
|
||||
.option('enhance', '-e', { hidden })
|
||||
.option('model', '-m <model>', { type: models })
|
||||
.option('orient', '-o <orient>', { type: orients })
|
||||
.option('sampler', '-s <sampler>', { type: samplers })
|
||||
.option('sampler', '-s <sampler>')
|
||||
.option('seed', '-x <seed:number>')
|
||||
.option('steps', '-t <step:number>', { hidden })
|
||||
.option('scale', '-c <scale:number>')
|
||||
@ -285,7 +322,6 @@ export function apply(ctx: Context, config: Config) {
|
||||
const parameters: Dict = {
|
||||
seed,
|
||||
n_samples: 1,
|
||||
sampler: options.sampler,
|
||||
uc: undesired.join(', '),
|
||||
ucPreset: 0,
|
||||
}
|
||||
@ -358,6 +394,29 @@ export function apply(ctx: Context, config: Config) {
|
||||
globalTasks.delete(id)
|
||||
}
|
||||
|
||||
function getPostData() {
|
||||
if (config.type !== 'sd-webui') {
|
||||
parameters.sampler = toNAISampler(options.sampler)
|
||||
return config.type === 'naifu'
|
||||
? { ...parameters, prompt: input }
|
||||
: { model, input, parameters }
|
||||
}
|
||||
|
||||
return {
|
||||
prompt: input,
|
||||
sampler_index: sdSamplers[options.sampler],
|
||||
...project(parameters, {
|
||||
n_samples: 'n_samples',
|
||||
seed: 'seed',
|
||||
negative_prompt: 'uc',
|
||||
cfg_scale: 'scale',
|
||||
steps: 'steps',
|
||||
width: 'width',
|
||||
height: 'height',
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
const path = config.type === 'sd-webui' ? '/sdapi/v1/txt2img' : config.type === 'naifu' ? '/generate-stream' : '/ai/generate-image'
|
||||
const request = () => ctx.http.axios(trimSlash(config.endpoint) + path, {
|
||||
method: 'POST',
|
||||
@ -366,23 +425,7 @@ export function apply(ctx: Context, config: Config) {
|
||||
...config.headers,
|
||||
authorization: 'Bearer ' + token,
|
||||
},
|
||||
data: config.type === 'sd-webui'
|
||||
? {
|
||||
prompt: input,
|
||||
sampler_index: samplersMapN2S(parameters.sampler),
|
||||
...project(parameters, {
|
||||
n_samples: 'n_samples',
|
||||
seed: 'seed',
|
||||
negative_prompt: 'uc',
|
||||
cfg_scale: 'scale',
|
||||
steps: 'steps',
|
||||
width: 'width',
|
||||
height: 'height',
|
||||
}),
|
||||
}
|
||||
: config.type === 'naifu'
|
||||
? { ...parameters, prompt: input }
|
||||
: { model, input, parameters },
|
||||
data: getPostData(),
|
||||
}).then((res) => {
|
||||
if (config.type === 'sd-webui') {
|
||||
return (res.data as StableDiffusionWebUI.Response).images[0]
|
||||
@ -452,5 +495,6 @@ export function apply(ctx: Context, config: Config) {
|
||||
cmd._options.model.fallback = config.model
|
||||
cmd._options.orient.fallback = config.orient
|
||||
cmd._options.sampler.fallback = config.sampler
|
||||
cmd._options.sampler.type = config.type === 'sd-webui' ? Object.keys(sdSamplers) : naiSamplers
|
||||
}, { immediate: true })
|
||||
}
|
||||
|
@ -8,9 +8,9 @@ commands:
|
||||
|
||||
options:
|
||||
enhance: Image Enhance Mode
|
||||
model: Set Model for Generation (safe, nai, furry)
|
||||
orient: Set Image Orientation (portrait, landscape, square)
|
||||
sampler: Set Sampler (k_euler_ancestral, k_euler, k_lms, plms, ddim)
|
||||
model: Set Model for Generation
|
||||
orient: Set Image Orientation
|
||||
sampler: Set Sampler
|
||||
anatomy.true: Filter Anatomically Incorrect Images
|
||||
anatomy.false: Allow Anatomically Incorrect Images
|
||||
seed: Set Random Seed
|
||||
|
@ -8,9 +8,9 @@ commands:
|
||||
|
||||
options:
|
||||
enhance: Mode d'amélioration de l'image
|
||||
model: Définir le modèle pour génération (safe, nai, furry)
|
||||
orient: Définir l'orientation de l'image (portrait, landscape, square)
|
||||
sampler: Définir l'échantillonneur (k_euler_ancestral, k_euler, k_lms, plms, ddim)
|
||||
model: Définir le modèle pour génération
|
||||
orient: Définir l'orientation de l'image
|
||||
sampler: Définir l'échantillonneur
|
||||
anatomy.true: Filtrer les images anatomiquement incorrectes
|
||||
anatomy.false: Autoriser les images anatomiquement incorrectes
|
||||
seed: Définir une graine aléatoire
|
||||
|
@ -8,9 +8,9 @@ commands:
|
||||
|
||||
options:
|
||||
enhance: 圖片增強模式
|
||||
model: 設定生成模型 (safe, nai, furry)
|
||||
orient: 設定圖片方向 (portrait, landscape, square)
|
||||
sampler: 設定取樣器 (k_euler_ancestral, k_euler, k_lms, plms, ddim)
|
||||
model: 設定生成模型
|
||||
orient: 設定圖片方向
|
||||
sampler: 設定取樣器
|
||||
anatomy.true: 過濾不合理構圖
|
||||
anatomy.false: 允許不合理構圖
|
||||
seed: 設置隨機種子
|
||||
|
@ -8,9 +8,9 @@ commands:
|
||||
|
||||
options:
|
||||
enhance: 图片增强模式
|
||||
model: 设定生成模型 (safe, nai, furry)
|
||||
orient: 设定图片方向 (portrait, landscape, square)
|
||||
sampler: 设置采样器 (k_euler_ancestral, k_euler, k_lms, plms, ddim)
|
||||
model: 设定生成模型
|
||||
orient: 设定图片方向
|
||||
sampler: 设置采样器
|
||||
anatomy.true: 过滤不合理构图
|
||||
anatomy.false: 允许不合理构图
|
||||
seed: 设置随机种子
|
||||
|
12
src/utils.ts
12
src/utils.ts
@ -184,15 +184,3 @@ export function resizeInput(size: Size): Size {
|
||||
return { width, height }
|
||||
}
|
||||
}
|
||||
|
||||
export function samplersMapN2S(sampler: string): string {
|
||||
switch (sampler) {
|
||||
case 'k_euler_ancestral': return 'Euler a'
|
||||
case 'k_euler': return 'Euler'
|
||||
case 'k_lms': return 'LMS'
|
||||
case 'plms': return 'PLMS'
|
||||
case 'ddim': return 'DDIM'
|
||||
}
|
||||
|
||||
return 'Euler a'
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user