feat: add support for ComfyUI (#254)

Co-authored-by: Shigma <shigma10826@gmail.com>
Co-authored-by: idranme <96647698+idranme@users.noreply.github.com>
This commit is contained in:
MieMieMieeeee 2024-06-17 07:24:49 +09:00 committed by GitHub
parent 39cdf4fc93
commit b27a02cf92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 381 additions and 3 deletions

View File

@ -0,0 +1,122 @@
{
"3": {
"inputs": {
"seed": 1,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 0.87,
"model": [
"14",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"12",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"6": {
"inputs": {
"text": "",
"clip": [
"14",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "",
"clip": [
"14",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"14",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"10": {
"inputs": {
"image": "example.png",
"upload": "image"
},
"class_type": "LoadImage",
"_meta": {
"title": "Load Image"
}
},
"12": {
"inputs": {
"pixels": [
"10",
0
],
"vae": [
"14",
2
]
},
"class_type": "VAEEncode",
"_meta": {
"title": "VAE Encode"
}
},
"14": {
"inputs": {
"ckpt_name": ""
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
}
}

View File

@ -0,0 +1,107 @@
{
"3": {
"inputs": {
"seed": 1,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 0.87,
"model": [
"14",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"16",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"6": {
"inputs": {
"text": "",
"clip": [
"14",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "",
"clip": [
"14",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"14",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"14": {
"inputs": {
"ckpt_name": ""
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"16": {
"inputs": {
"width": 512,
"height": 800,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
}
}

View File

@ -33,6 +33,7 @@ type Orient = keyof typeof orientMap
// Identifier lists derived from the lookup maps above; used for schema unions.
export const models = Object.keys(modelMap) as Model[]
export const orients = Object.keys(orientMap) as Orient[]
// Scheduler options offered for the non-ComfyUI backends.
export const scheduler = ['native', 'karras', 'exponential', 'polyexponential'] as const
// Scheduler options offered for the ComfyUI backend (fed into the KSampler
// node's `scheduler` input — see the workflow handling in the plugin entry).
export const schedulerComfyUI = ['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform'] as const
export namespace sampler {
export const nai = {
@ -87,6 +88,31 @@ export namespace sampler {
DDIM_ka: 'DDIM Karras',
}
// ComfyUI sampler identifiers mapped to human-readable display names.
// Keys are the raw `sampler_name` values sent to the KSampler node; values
// are the labels shown in the configuration UI. Declaration order is the
// order the options appear in the generated schema, so do not reorder.
export const comfyui = {
euler: 'Euler',
euler_ancestral: 'Euler ancestral',
heun: 'Heun',
heunpp2: 'Heun++ 2',
dpm_2: 'DPM 2',
dpm_2_ancestral: 'DPM 2 ancestral',
lms: 'LMS',
dpm_fast: 'DPM fast',
dpm_adaptive: 'DPM adaptive',
dpmpp_2s_ancestral: 'DPM++ 2S ancestral',
dpmpp_sde: 'DPM++ SDE',
dpmpp_sde_gpu: 'DPM++ SDE GPU',
dpmpp_2m: 'DPM++ 2M',
dpmpp_2m_sde: 'DPM++ 2M SDE',
dpmpp_2m_sde_gpu: 'DPM++ 2M SDE GPU',
dpmpp_3m_sde: 'DPM++ 3M SDE',
dpmpp_3m_sde_gpu: 'DPM++ 3M SDE GPU',
ddpm: 'DDPM',
lcm: 'LCM',
ddim: 'DDIM',
uni_pc: 'UniPC',
uni_pc_bh2: 'UniPC BH2',
}
export function createSchema(map: Dict<string>) {
return Schema.union(Object.entries(map).map(([key, value]) => {
return Schema.const(key).description(value)
@ -201,7 +227,7 @@ interface ParamConfig {
}
export interface Config extends PromptConfig, ParamConfig {
type: 'token' | 'login' | 'naifu' | 'sd-webui' | 'stable-horde'
type: 'token' | 'login' | 'naifu' | 'sd-webui' | 'stable-horde' | 'comfyui'
token?: string
email?: string
password?: string
@ -220,6 +246,8 @@ export interface Config extends PromptConfig, ParamConfig {
maxConcurrency?: number
pollInterval?: number
trustedWorkers?: boolean
workflowText2Image?: string
workflowImage2Image?: string
}
export const Config = Schema.intersect([
@ -230,6 +258,7 @@ export const Config = Schema.intersect([
Schema.const('naifu').description('naifu'),
Schema.const('sd-webui').description('sd-webui'),
Schema.const('stable-horde').description('Stable Horde'),
Schema.const('comfyui').description('ComfyUI'),
]).default('token').description('登录方式。'),
}).description('登录设置'),
@ -278,6 +307,12 @@ export const Config = Schema.intersect([
trustedWorkers: Schema.boolean().description('是否只请求可信任工作节点。').default(false),
pollInterval: Schema.number().role('time').description('轮询进度间隔时长。').default(Time.second),
}),
Schema.object({
type: Schema.const('comfyui'),
endpoint: Schema.string().description('API 服务器地址。').required(),
headers: Schema.dict(String).role('table').description('要附加的额外请求头。'),
pollInterval: Schema.number().role('time').description('轮询进度间隔时长。').default(Time.second),
}),
]),
Schema.object({
@ -322,6 +357,20 @@ export const Config = Schema.intersect([
type: Schema.const('naifu').required(),
sampler: sampler.createSchema(sampler.nai),
}),
Schema.object({
type: Schema.const('comfyui').required(),
sampler: sampler.createSchema(sampler.comfyui).description('默认的采样器。').required(),
model: Schema.string().description('默认的生成模型的文件名。').required(),
workflowText2Image: Schema.path({
filters: [{ name: '', extensions: ['.json'] }],
allowCreate: true,
}).description('API 格式的文本到图像工作流。'),
workflowImage2Image: Schema.path({
filters: [{ name: '', extensions: ['.json'] }],
allowCreate: true,
}).description('API 格式的图像到图像工作流。'),
scheduler: Schema.union(schedulerComfyUI).description('默认的调度器。').default('normal'),
}),
Schema.intersect([
Schema.object({
model: Schema.union(models).loose().description('默认的生成模型。').default('nai-v3'),

View File

@ -5,6 +5,8 @@ import { closestMultiple, download, forceDataPrefix, getImageSize, login, Networ
import { } from '@koishijs/translator'
import { } from '@koishijs/plugin-help'
import AdmZip from 'adm-zip'
import { resolve } from 'path'
import { readFile } from 'fs/promises'
export * from './config'
@ -301,12 +303,14 @@ export function apply(ctx: Context, config: Config) {
return '/api/v2/generate/async'
case 'naifu':
return '/generate-stream'
case 'comfyui':
return '/prompt'
default:
return '/ai/generate-image'
}
})()
const getPayload = () => {
const getPayload = async () => {
switch (config.type) {
case 'login':
case 'token':
@ -392,6 +396,76 @@ export function apply(ctx: Context, config: Config) {
r2: true,
}
}
case 'comfyui': {
  // Resolve the workflow file to use: a user-configured path (relative to
  // the bot's base directory) or the bundled default API-format workflow.
  const workflowText2Image = config.workflowText2Image
    ? resolve(ctx.baseDir, config.workflowText2Image)
    : resolve(__dirname, '../data/default-comfyui-t2i-wf.json')
  const workflowImage2Image = config.workflowImage2Image
    ? resolve(ctx.baseDir, config.workflowImage2Image)
    : resolve(__dirname, '../data/default-comfyui-i2i-wf.json')
  const workflow = image ? workflowImage2Image : workflowText2Image
  logger.debug('workflow:', workflow)
  const prompt = JSON.parse(await readFile(workflow, 'utf8'))
  // ComfyUI cannot take inline image data in a prompt, so for img2img the
  // input image is uploaded to the server first and the workflow's LoadImage
  // node is pointed at the stored file.
  if (image) {
    const body = new FormData()
    const capture = /^data:([\w/.+-]+);base64,(.*)$/.exec(image.dataUrl)
    const [, mime] = capture
    // Give the upload a unique name; extension derived from the MIME type.
    let name = Date.now().toString()
    const ext = mime === 'image/jpeg' ? 'jpg' : mime === 'image/png' ? 'png' : ''
    if (ext) name += `.${ext}`
    const imageFile = new Blob([image.buffer], { type: mime })
    body.append('image', imageFile, name)
    const res = await ctx.http(trimSlash(config.endpoint) + '/upload/image', {
      method: 'POST',
      headers: {
        ...config.headers,
      },
      data: body,
    })
    if (res.status === 200) {
      const data = res.data
      let imagePath = data.name
      if (data.subfolder) imagePath = data.subfolder + '/' + imagePath
      // Only the first LoadImage node receives the uploaded image.
      for (const nodeId in prompt) {
        if (prompt[nodeId].class_type === 'LoadImage') {
          prompt[nodeId].inputs.image = imagePath
          break
        }
      }
    } else {
      throw new SessionError('commands.novelai.messages.unknown-error')
    }
  }
  // Fill the sampling parameters into the first KSampler node and the nodes
  // it references (positive/negative prompt encoders and the latent image).
  // NOTE(review): width/height/batch_size are written unconditionally; in the
  // img2img workflow the latent node is a VAEEncode, which presumably ignores
  // those inputs — confirm against the ComfyUI node definitions.
  for (const nodeId in prompt) {
    if (prompt[nodeId].class_type === 'KSampler') {
      prompt[nodeId].inputs.seed = parameters.seed
      prompt[nodeId].inputs.steps = parameters.steps
      prompt[nodeId].inputs.cfg = parameters.scale
      prompt[nodeId].inputs.sampler_name = options.sampler
      prompt[nodeId].inputs.denoise = options.strength ?? config.strength
      prompt[nodeId].inputs.scheduler = options.scheduler ?? config.scheduler
      const positiveNodeId = prompt[nodeId].inputs.positive[0]
      // (was `negativeeNodeId` — misspelled local, renamed)
      const negativeNodeId = prompt[nodeId].inputs.negative[0]
      const latentImageNodeId = prompt[nodeId].inputs.latent_image[0]
      prompt[positiveNodeId].inputs.text = parameters.prompt
      prompt[negativeNodeId].inputs.text = parameters.uc
      prompt[latentImageNodeId].inputs.width = parameters.width
      prompt[latentImageNodeId].inputs.height = parameters.height
      prompt[latentImageNodeId].inputs.batch_size = parameters.n_samples
      break
    }
  }
  // Point the first CheckpointLoaderSimple node at the requested model.
  for (const nodeId in prompt) {
    if (prompt[nodeId].class_type === 'CheckpointLoaderSimple') {
      prompt[nodeId].inputs.ckpt_name = options.model ?? config.model
      break
    }
  }
  logger.debug('prompt:', prompt)
  return { prompt }
}
}
}
@ -418,7 +492,7 @@ export function apply(ctx: Context, config: Config) {
...config.headers,
...getHeaders(),
},
data: getPayload(),
data: await getPayload(),
})
if (config.type === 'sd-webui') {
@ -453,6 +527,32 @@ export function apply(ctx: Context, config: Config) {
const b64 = Buffer.from(imgRes.data).toString('base64')
return forceDataPrefix(b64, imgRes.headers.get('content-type'))
}
if (config.type === 'comfyui') {
// The /prompt endpoint only returns a prompt_id; the generated images appear
// later in the /history endpoint once the workflow has finished executing,
// so poll /history until an outputs entry exists for this prompt.
// get filenames from history
const promptId = res.data.prompt_id
const check = () => ctx.http.get(trimSlash(config.endpoint) + '/history/' + promptId)
.then((res) => res[promptId] && res[promptId].outputs)
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
let outputs
// NOTE(review): this loop has no upper bound — if the server never records
// the prompt (e.g. it was dropped or errored), polling continues forever;
// consider adding a timeout or a max attempt count.
while (!(outputs = await check())) {
await sleep(config.pollInterval)
}
// get images by filename
// Each output node may list several images; only the first image of each
// node is fetched (inner `break`), via the /view endpoint which serves a
// file identified by filename + subfolder + type.
const imagesOutput: { data: ArrayBuffer, mime: string }[] = [];
for (const nodeId in outputs) {
const nodeOutput = outputs[nodeId]
if ('images' in nodeOutput) {
for (const image of nodeOutput['images']) {
const urlValues = new URLSearchParams({ filename: image['filename'], subfolder: image['subfolder'], type: image['type'] }).toString()
const imgRes = await ctx.http(trimSlash(config.endpoint) + '/view?' + urlValues)
imagesOutput.push({ data: imgRes.data, mime: imgRes.headers.get('content-type') })
break
}
}
}
// return first image
// Only the first collected image is returned as a data URL, even when
// multiple output nodes produced images.
return forceDataPrefix(Buffer.from(imagesOutput[0].data).toString('base64'), imagesOutput[0].mime)
}
// event: newImage
// id: 1
// data: