From b27a02cf92d4b8cf2aba738f60d1d75eb7fd994d Mon Sep 17 00:00:00 2001
From: MieMieMieeeee <34560903+MieMieMieeeee@users.noreply.github.com>
Date: Mon, 17 Jun 2024 07:24:49 +0900
Subject: [PATCH] feat: add support for ComfyUI (#254)

Co-authored-by: Shigma
Co-authored-by: idranme <96647698+idranme@users.noreply.github.com>
---
 data/default-comfyui-i2i-wf.json | 122 +++++++++++++++++++++++++++++++
 data/default-comfyui-t2i-wf.json | 107 +++++++++++++++++++++++++++
 src/config.ts                    |  51 ++++++++++++-
 src/index.ts                     | 104 +++++++++++++++++++++++++-
 4 files changed, 381 insertions(+), 3 deletions(-)
 create mode 100644 data/default-comfyui-i2i-wf.json
 create mode 100644 data/default-comfyui-t2i-wf.json

diff --git a/data/default-comfyui-i2i-wf.json b/data/default-comfyui-i2i-wf.json
new file mode 100644
index 0000000..d42da6f
--- /dev/null
+++ b/data/default-comfyui-i2i-wf.json
@@ -0,0 +1,122 @@
+{
+  "3": {
+    "inputs": {
+      "seed": 1,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "euler",
+      "scheduler": "normal",
+      "denoise": 0.87,
+      "model": [
+        "14",
+        0
+      ],
+      "positive": [
+        "6",
+        0
+      ],
+      "negative": [
+        "7",
+        0
+      ],
+      "latent_image": [
+        "12",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "6": {
+    "inputs": {
+      "text": "",
+      "clip": [
+        "14",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "7": {
+    "inputs": {
+      "text": "",
+      "clip": [
+        "14",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "8": {
+    "inputs": {
+      "samples": [
+        "3",
+        0
+      ],
+      "vae": [
+        "14",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "9": {
+    "inputs": {
+      "filename_prefix": "ComfyUI",
+      "images": [
+        "8",
+        0
+      ]
+    },
+    "class_type": "SaveImage",
+    "_meta": {
+      "title": "Save Image"
+    }
+  },
+  "10": {
+    "inputs": {
+      "image": "example.png",
+      "upload": "image"
+    },
+    "class_type": "LoadImage",
+    "_meta": {
+      "title": "Load Image"
+    }
+  },
+  "12": {
+    "inputs": {
+      "pixels": [
+        "10",
+        0
+      ],
+      "vae": [
+        "14",
+        2
+      ]
+    },
+    "class_type": "VAEEncode",
+    "_meta": {
+      "title": "VAE Encode"
+    }
+  },
+  "14": {
+    "inputs": {
+      "ckpt_name": ""
+    },
+    "class_type": "CheckpointLoaderSimple",
+    "_meta": {
+      "title": "Load Checkpoint"
+    }
+  }
+}
\ No newline at end of file
diff --git a/data/default-comfyui-t2i-wf.json b/data/default-comfyui-t2i-wf.json
new file mode 100644
index 0000000..16292cd
--- /dev/null
+++ b/data/default-comfyui-t2i-wf.json
@@ -0,0 +1,107 @@
+{
+  "3": {
+    "inputs": {
+      "seed": 1,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "euler",
+      "scheduler": "normal",
+      "denoise": 0.87,
+      "model": [
+        "14",
+        0
+      ],
+      "positive": [
+        "6",
+        0
+      ],
+      "negative": [
+        "7",
+        0
+      ],
+      "latent_image": [
+        "16",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "6": {
+    "inputs": {
+      "text": "",
+      "clip": [
+        "14",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "7": {
+    "inputs": {
+      "text": "",
+      "clip": [
+        "14",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "8": {
+    "inputs": {
+      "samples": [
+        "3",
+        0
+      ],
+      "vae": [
+        "14",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "9": {
+    "inputs": {
+      "filename_prefix": "ComfyUI",
+      "images": [
+        "8",
+        0
+      ]
+    },
+    "class_type": "SaveImage",
+    "_meta": {
+      "title": "Save Image"
"Save Image" + } + }, + "14": { + "inputs": { + "ckpt_name": "" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "16": { + "inputs": { + "width": 512, + "height": 800, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + } +} \ No newline at end of file diff --git a/src/config.ts b/src/config.ts index da554a8..3db60f4 100644 --- a/src/config.ts +++ b/src/config.ts @@ -33,6 +33,7 @@ type Orient = keyof typeof orientMap export const models = Object.keys(modelMap) as Model[] export const orients = Object.keys(orientMap) as Orient[] export const scheduler = ['native', 'karras', 'exponential', 'polyexponential'] as const +export const schedulerComfyUI = ['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform'] as const export namespace sampler { export const nai = { @@ -87,6 +88,31 @@ export namespace sampler { DDIM_ka: 'DDIM Karras', } + export const comfyui = { + euler: 'Euler', + euler_ancestral: 'Euler ancestral', + heun: 'Heun', + heunpp2: 'Heun++ 2', + dpm_2: 'DPM 2', + dpm_2_ancestral: 'DPM 2 ancestral', + lms: 'LMS', + dpm_fast: 'DPM fast', + dpm_adaptive: 'DPM adaptive', + dpmpp_2s_ancestral: 'DPM++ 2S ancestral', + dpmpp_sde: 'DPM++ SDE', + dpmpp_sde_gpu: 'DPM++ SDE GPU', + dpmpp_2m: 'DPM++ 2M', + dpmpp_2m_sde: 'DPM++ 2M SDE', + dpmpp_2m_sde_gpu: 'DPM++ 2M SDE GPU', + dpmpp_3m_sde: 'DPM++ 3M SDE', + dpmpp_3m_sde_gpu: 'DPM++ 3M SDE GPU', + ddpm: 'DDPM', + lcm: 'LCM', + ddim: 'DDIM', + uni_pc: 'UniPC', + uni_pc_bh2: 'UniPC BH2', + } + export function createSchema(map: Dict) { return Schema.union(Object.entries(map).map(([key, value]) => { return Schema.const(key).description(value) @@ -201,7 +227,7 @@ interface ParamConfig { } export interface Config extends PromptConfig, ParamConfig { - type: 'token' | 'login' | 'naifu' | 'sd-webui' | 'stable-horde' + type: 'token' | 'login' | 'naifu' | 'sd-webui' | 'stable-horde' | 'comfyui' token?: string email?: string password?: string @@ -220,6 +246,8 @@ export interface Config extends PromptConfig, ParamConfig { maxConcurrency?: number pollInterval?: number trustedWorkers?: boolean + workflowText2Image?: string + workflowImage2Image?: string } export const Config = Schema.intersect([ @@ -230,6 +258,7 @@ export const Config = Schema.intersect([ Schema.const('naifu').description('naifu'), Schema.const('sd-webui').description('sd-webui'), Schema.const('stable-horde').description('Stable Horde'), + Schema.const('comfyui').description('ComfyUI'), ]).default('token').description('登录方式。'), }).description('登录设置'), @@ -278,6 +307,12 @@ export const Config = Schema.intersect([ trustedWorkers: Schema.boolean().description('是否只请求可信任工作节点。').default(false), pollInterval: Schema.number().role('time').description('轮询进度间隔时长。').default(Time.second), }), + Schema.object({ + type: Schema.const('comfyui'), + endpoint: Schema.string().description('API 服务器地址。').required(), + headers: Schema.dict(String).role('table').description('要附加的额外请求头。'), + pollInterval: Schema.number().role('time').description('轮询进度间隔时长。').default(Time.second), + }), ]), Schema.object({ @@ -322,6 +357,20 @@ export const Config = Schema.intersect([ type: Schema.const('naifu').required(), sampler: sampler.createSchema(sampler.nai), }), + Schema.object({ + type: Schema.const('comfyui').required(), + sampler: sampler.createSchema(sampler.comfyui).description('默认的采样器。').required(), + model: Schema.string().description('默认的生成模型的文件名。').required(), + workflowText2Image: 
+        filters: [{ name: '', extensions: ['.json'] }],
+        allowCreate: true,
+      }).description('API 格式的文本到图像工作流。'),
+      workflowImage2Image: Schema.path({
+        filters: [{ name: '', extensions: ['.json'] }],
+        allowCreate: true,
+      }).description('API 格式的图像到图像工作流。'),
+      scheduler: Schema.union(schedulerComfyUI).description('默认的调度器。').default('normal'),
+    }),
     Schema.intersect([
       Schema.object({
         model: Schema.union(models).loose().description('默认的生成模型。').default('nai-v3'),
diff --git a/src/index.ts b/src/index.ts
index c3aebba..66e36d7 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -5,6 +5,8 @@ import { closestMultiple, download, forceDataPrefix, getImageSize, login, Networ
 import { } from '@koishijs/translator'
 import { } from '@koishijs/plugin-help'
 import AdmZip from 'adm-zip'
+import { resolve } from 'path'
+import { readFile } from 'fs/promises'
 
 export * from './config'
 
@@ -301,12 +303,14 @@
            return '/api/v2/generate/async'
          case 'naifu':
            return '/generate-stream'
+         case 'comfyui':
+           return '/prompt'
          default:
            return '/ai/generate-image'
        }
      })()
 
-     const getPayload = () => {
+     const getPayload = async () => {
        switch (config.type) {
          case 'login':
          case 'token':
@@ -392,6 +396,76 @@
              r2: true,
            }
          }
+         case 'comfyui': {
+           const workflowText2Image = config.workflowText2Image ? resolve(ctx.baseDir, config.workflowText2Image) : resolve(__dirname, '../data/default-comfyui-t2i-wf.json')
+           const workflowImage2Image = config.workflowImage2Image ? resolve(ctx.baseDir, config.workflowImage2Image) : resolve(__dirname, '../data/default-comfyui-i2i-wf.json')
+           const workflow = image ? workflowImage2Image : workflowText2Image
+           logger.debug('workflow:', workflow)
+           const prompt = JSON.parse(await readFile(workflow, 'utf8'))
+
+           // have to upload image to the comfyui server first
+           if (image) {
+             const body = new FormData()
+             const capture = /^data:([\w/.+-]+);base64,(.*)$/.exec(image.dataUrl)
+             const [, mime] = capture
+
+             let name = Date.now().toString()
+             const ext = mime === 'image/jpeg' ? 'jpg' : mime === 'image/png' ? 'png' : ''
+             if (ext) name += `.${ext}`
+             const imageFile = new Blob([image.buffer], { type: mime })
+             body.append('image', imageFile, name)
+             const res = await ctx.http(trimSlash(config.endpoint) + '/upload/image', {
+               method: 'POST',
+               headers: {
+                 ...config.headers,
+               },
+               data: body,
+             })
+             if (res.status === 200) {
+               const data = res.data
+               let imagePath = data.name
+               if (data.subfolder) imagePath = data.subfolder + '/' + imagePath
+
+               for (const nodeId in prompt) {
+                 if (prompt[nodeId].class_type === 'LoadImage') {
+                   prompt[nodeId].inputs.image = imagePath
+                   break
+                 }
+               }
+             } else {
+               throw new SessionError('commands.novelai.messages.unknown-error')
+             }
+           }
+
+           // only change the first node in the workflow
+           for (const nodeId in prompt) {
+             if (prompt[nodeId].class_type === 'KSampler') {
+               prompt[nodeId].inputs.seed = parameters.seed
+               prompt[nodeId].inputs.steps = parameters.steps
+               prompt[nodeId].inputs.cfg = parameters.scale
+               prompt[nodeId].inputs.sampler_name = options.sampler
+               prompt[nodeId].inputs.denoise = options.strength ?? config.strength
+               prompt[nodeId].inputs.scheduler = options.scheduler ?? config.scheduler
+               const positiveNodeId = prompt[nodeId].inputs.positive[0]
+               const negativeNodeId = prompt[nodeId].inputs.negative[0]
+               const latentImageNodeId = prompt[nodeId].inputs.latent_image[0]
+               prompt[positiveNodeId].inputs.text = parameters.prompt
+               prompt[negativeNodeId].inputs.text = parameters.uc
+               prompt[latentImageNodeId].inputs.width = parameters.width
+               prompt[latentImageNodeId].inputs.height = parameters.height
+               prompt[latentImageNodeId].inputs.batch_size = parameters.n_samples
+               break
+             }
+           }
+           for (const nodeId in prompt) {
+             if (prompt[nodeId].class_type === 'CheckpointLoaderSimple') {
+               prompt[nodeId].inputs.ckpt_name = options.model ?? config.model
+               break
+             }
+           }
+           logger.debug('prompt:', prompt)
+           return { prompt }
+         }
       }
     }
 
@@ -418,7 +492,7 @@
          ...config.headers,
          ...getHeaders(),
        },
-       data: getPayload(),
+       data: await getPayload(),
      })
 
      if (config.type === 'sd-webui') {
@@ -453,6 +527,32 @@
        const b64 = Buffer.from(imgRes.data).toString('base64')
        return forceDataPrefix(b64, imgRes.headers.get('content-type'))
      }
+     if (config.type === 'comfyui') {
+       // get filenames from history
+       const promptId = res.data.prompt_id
+       const check = () => ctx.http.get(trimSlash(config.endpoint) + '/history/' + promptId)
+         .then((res) => res[promptId] && res[promptId].outputs)
+       const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
+       let outputs
+       while (!(outputs = await check())) {
+         await sleep(config.pollInterval)
+       }
+       // get images by filename
+       const imagesOutput: { data: ArrayBuffer, mime: string }[] = []
+       for (const nodeId in outputs) {
+         const nodeOutput = outputs[nodeId]
+         if ('images' in nodeOutput) {
+           for (const image of nodeOutput['images']) {
+             const urlValues = new URLSearchParams({ filename: image['filename'], subfolder: image['subfolder'], type: image['type'] }).toString()
+             const imgRes = await ctx.http(trimSlash(config.endpoint) + '/view?' + urlValues)
+             imagesOutput.push({ data: imgRes.data, mime: imgRes.headers.get('content-type') })
+             break
+           }
+         }
+       }
+       // return first image
+       return forceDataPrefix(Buffer.from(imagesOutput[0].data).toString('base64'), imagesOutput[0].mime)
+     }
      // event: newImage
      // id: 1
      // data: