diff --git a/src/config.ts b/src/config.ts index 71cab73..fbc797f 100644 --- a/src/config.ts +++ b/src/config.ts @@ -117,6 +117,7 @@ export interface Config extends PromptConfig { allowAnlas?: boolean | number endpoint?: string headers?: Dict + maxIteration?: number maxRetryCount?: number requestTimeout?: number recallTimeout?: number @@ -201,6 +202,7 @@ export const Config = Schema.intersect([ Schema.const('default').description('发送图片和关键信息'), Schema.const('verbose').description('发送全部信息'), ]).description('输出方式。').default('default'), + maxIteration: Schema.natural().description('允许的最大绘制次数。').default(1), maxRetryCount: Schema.natural().description('连接失败时最大的重试次数。').default(3), requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute), recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0), diff --git a/src/index.ts b/src/index.ts index 9e1f9ac..19c9c57 100644 --- a/src/index.ts +++ b/src/index.ts @@ -103,9 +103,14 @@ export function apply(ctx: Context, config: Config) { .option('strength', '-N ', { hidden: restricted }) .option('undesired', '-u ') .option('noTranslator', '-T', { hidden: () => !ctx.translator || !config.translator }) + .option('iterations', '-i ', { fallback: 1, hidden: () => config.maxIteration <= 1 }) .action(async ({ session, options }, input) => { if (!input?.trim()) return session.execute('help novelai') + if (options.iterations && options.iterations > config.maxIteration) { + return session.text('.exceed-max-iteration', [config.maxIteration]) + } + let imgUrl: string, image: ImageData if (!restricted(session)) { input = segment.transform(input, { @@ -209,13 +214,14 @@ export function apply(ctx: Context, config: Config) { }) } - const id = Math.random().toString(36).slice(2) + const getRandomId = () => Math.random().toString(36).slice(2) + const iterations = Array(options.iterations).fill(0).map(getRandomId) if (config.maxConcurrency) { const store = tasks[session.cid] ||= new Set() if (store.size >= config.maxConcurrency) { return session.text('.concurrent-jobs') } else { - store.add(id) + iterations.forEach((id) => store.add(id)) } } @@ -223,8 +229,8 @@ export function apply(ctx: Context, config: Config) { ? session.text('.pending', [globalTasks.size]) : session.text('.waiting')) - globalTasks.add(id) - const cleanUp = () => { + iterations.forEach((id) => globalTasks.add(id)) + const cleanUp = (id: string) => { tasks[session.cid]?.delete(id) globalTasks.delete(id) } @@ -240,7 +246,7 @@ export function apply(ctx: Context, config: Config) { } })() - const data = (() => { + const getPayload = () => { if (config.type !== 'sd-webui') { parameters.sampler = sampler.sd2nai(options.sampler) parameters.image = image?.base64 // NovelAI / NAIFU accepts bare base64 encoded image @@ -263,86 +269,97 @@ export function apply(ctx: Context, config: Config) { denoising_strength: 'strength', }), } - })() + } - const request = () => ctx.http.axios(trimSlash(config.endpoint) + path, { - method: 'POST', - timeout: config.requestTimeout, - headers: { - ...config.headers, - authorization: 'Bearer ' + token, - }, - data, - }).then((res) => { - if (config.type === 'sd-webui') { - return stripDataPrefix((res.data as StableDiffusionWebUI.Response).images[0]) + const iterate = async () => { + const request = () => ctx.http.axios(trimSlash(config.endpoint) + path, { + method: 'POST', + timeout: config.requestTimeout, + headers: { + ...config.headers, + authorization: 'Bearer ' + token, + }, + data: getPayload(), + }).then((res) => { + if (config.type === 'sd-webui') { + return stripDataPrefix((res.data as StableDiffusionWebUI.Response).images[0]) + } + // event: newImage + // id: 1 + // data: + return res.data?.slice(27) + }) + + let base64: string, count = 0 + while (true) { + try { + base64 = await request() + break + } catch (err) { + if (Quester.isAxiosError(err)) { + if (err.code && err.code !== 'ETIMEDOUT' && ++count < config.maxRetryCount) { + continue + } + } + + return await session.send(handleError(session, err)) + } } - // event: newImage - // id: 1 - // data: - return res.data?.slice(27) - }) - let base64: string, count = 0 - while (true) { - try { - base64 = await request() - cleanUp() - break - } catch (err) { - if (Quester.isAxiosError(err)) { - if (err.code && err.code !== 'ETIMEDOUT' && ++count < config.maxRetryCount) { - continue + if (!base64.trim()) return await session.send(session.text('.empty-response')) + + function getContent() { + if (config.output === 'minimal') return segment.image('base64://' + base64) + const attrs = { + userId: session.userId, + nickname: session.author?.nickname || session.username, + } + const result = segment('figure') + const lines = [`seed = ${parameters.seed}`] + if (config.output === 'verbose') { + if (!thirdParty()) { + lines.push(`model = ${model}`) + } + lines.push( + `sampler = ${options.sampler}`, + `steps = ${parameters.steps}`, + `scale = ${parameters.scale}`, + ) + if (parameters.image) { + lines.push( + `strength = ${parameters.strength}`, + `noise = ${parameters.noise}`, + ) } } - cleanUp() - return handleError(session, err) + result.children.push(segment('message', attrs, lines.join('\n'))) + result.children.push(segment('message', attrs, `prompt = ${prompt}`)) + if (config.output === 'verbose') { + result.children.push(segment('message', attrs, `undesired = ${uc}`)) + } + result.children.push(segment('message', attrs, segment.image('base64://' + base64))) + return result + } + + const messageIds = await session.send(getContent()) + if (messageIds.length && config.recallTimeout) { + ctx.setTimeout(() => { + for (const id of messageIds) { + session.bot.deleteMessage(session.channelId, id) + } + }, config.recallTimeout) } } - if (!base64.trim()) return session.text('.empty-response') - - function getContent() { - if (config.output === 'minimal') return segment.image('base64://' + base64) - const attrs = { - userId: session.userId, - nickname: session.author?.nickname || session.username, + while (iterations.length) { + try { + await iterate() + cleanUp(iterations.pop()) + parameters.seed++ + } catch (err) { + iterations.forEach(cleanUp) + throw err } - const result = segment('figure') - const lines = [`seed = ${seed}`] - if (config.output === 'verbose') { - if (!thirdParty()) { - lines.push(`model = ${model}`) - } - lines.push( - `sampler = ${options.sampler}`, - `steps = ${parameters.steps}`, - `scale = ${parameters.scale}`, - ) - if (parameters.image) { - lines.push( - `strength = ${parameters.strength}`, - `noise = ${parameters.noise}`, - ) - } - } - result.children.push(segment('message', attrs, lines.join('\n'))) - result.children.push(segment('message', attrs, `prompt = ${prompt}`)) - if (config.output === 'verbose') { - result.children.push(segment('message', attrs, `undesired = ${uc}`)) - } - result.children.push(segment('message', attrs, segment.image('base64://' + base64))) - return result - } - - const ids = await session.send(getContent()) - - if (config.recallTimeout) { - ctx.setTimeout(() => { - for (const id of ids) { - session.bot.deleteMessage(session.channelId, id) - } - }, config.recallTimeout) } }) diff --git a/src/locales/zh-CN.yml b/src/locales/zh-CN.yml index 24ff2a6..1e20f62 100644 --- a/src/locales/zh-CN.yml +++ b/src/locales/zh-CN.yml @@ -19,8 +19,10 @@ commands: noise: 图片噪声强度 undesired: 排除标签 noTranslator: 禁用自动翻译 + iterations: 设置绘制次数 messages: + exceed-max-iteration: 超过最大绘制次数。 expect-prompt: 请输入标签。 expect-image: 请输入图片。 latin-only: 只接受英文输入。