feat: add --iteration option for generate multiple images (#131)

Co-authored-by: Shigma <shigma10826@gmail.com>
This commit is contained in:
Maiko Sinkyaet Tan 2022-11-26 00:32:37 +08:00 committed by GitHub
parent 4b7f0b19fb
commit cce62cf49f
3 changed files with 97 additions and 76 deletions

View File

@ -117,6 +117,7 @@ export interface Config extends PromptConfig {
allowAnlas?: boolean | number
endpoint?: string
headers?: Dict<string>
maxIteration?: number
maxRetryCount?: number
requestTimeout?: number
recallTimeout?: number
@ -201,6 +202,7 @@ export const Config = Schema.intersect([
Schema.const('default').description('发送图片和关键信息'),
Schema.const('verbose').description('发送全部信息'),
]).description('输出方式。').default('default'),
maxIteration: Schema.natural().description('允许的最大绘制次数。').default(1),
maxRetryCount: Schema.natural().description('连接失败时最大的重试次数。').default(3),
requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute),
recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0),

View File

@ -103,9 +103,14 @@ export function apply(ctx: Context, config: Config) {
.option('strength', '-N <strength:number>', { hidden: restricted })
.option('undesired', '-u <undesired>')
.option('noTranslator', '-T', { hidden: () => !ctx.translator || !config.translator })
.option('iterations', '-i <iterations:posint>', { fallback: 1, hidden: () => config.maxIteration <= 1 })
.action(async ({ session, options }, input) => {
if (!input?.trim()) return session.execute('help novelai')
if (options.iterations && options.iterations > config.maxIteration) {
return session.text('.exceed-max-iteration', [config.maxIteration])
}
let imgUrl: string, image: ImageData
if (!restricted(session)) {
input = segment.transform(input, {
@ -209,13 +214,14 @@ export function apply(ctx: Context, config: Config) {
})
}
const id = Math.random().toString(36).slice(2)
const getRandomId = () => Math.random().toString(36).slice(2)
const iterations = Array(options.iterations).fill(0).map(getRandomId)
if (config.maxConcurrency) {
const store = tasks[session.cid] ||= new Set()
if (store.size >= config.maxConcurrency) {
return session.text('.concurrent-jobs')
} else {
store.add(id)
iterations.forEach((id) => store.add(id))
}
}
@ -223,8 +229,8 @@ export function apply(ctx: Context, config: Config) {
? session.text('.pending', [globalTasks.size])
: session.text('.waiting'))
globalTasks.add(id)
const cleanUp = () => {
iterations.forEach((id) => globalTasks.add(id))
const cleanUp = (id: string) => {
tasks[session.cid]?.delete(id)
globalTasks.delete(id)
}
@ -240,7 +246,7 @@ export function apply(ctx: Context, config: Config) {
}
})()
const data = (() => {
const getPayload = () => {
if (config.type !== 'sd-webui') {
parameters.sampler = sampler.sd2nai(options.sampler)
parameters.image = image?.base64 // NovelAI / NAIFU accepts bare base64 encoded image
@ -263,86 +269,97 @@ export function apply(ctx: Context, config: Config) {
denoising_strength: 'strength',
}),
}
})()
}
const request = () => ctx.http.axios(trimSlash(config.endpoint) + path, {
method: 'POST',
timeout: config.requestTimeout,
headers: {
...config.headers,
authorization: 'Bearer ' + token,
},
data,
}).then((res) => {
if (config.type === 'sd-webui') {
return stripDataPrefix((res.data as StableDiffusionWebUI.Response).images[0])
const iterate = async () => {
const request = () => ctx.http.axios(trimSlash(config.endpoint) + path, {
method: 'POST',
timeout: config.requestTimeout,
headers: {
...config.headers,
authorization: 'Bearer ' + token,
},
data: getPayload(),
}).then((res) => {
if (config.type === 'sd-webui') {
return stripDataPrefix((res.data as StableDiffusionWebUI.Response).images[0])
}
// event: newImage
// id: 1
// data:
return res.data?.slice(27)
})
let base64: string, count = 0
while (true) {
try {
base64 = await request()
break
} catch (err) {
if (Quester.isAxiosError(err)) {
if (err.code && err.code !== 'ETIMEDOUT' && ++count < config.maxRetryCount) {
continue
}
}
return await session.send(handleError(session, err))
}
}
// event: newImage
// id: 1
// data:
return res.data?.slice(27)
})
let base64: string, count = 0
while (true) {
try {
base64 = await request()
cleanUp()
break
} catch (err) {
if (Quester.isAxiosError(err)) {
if (err.code && err.code !== 'ETIMEDOUT' && ++count < config.maxRetryCount) {
continue
if (!base64.trim()) return await session.send(session.text('.empty-response'))
function getContent() {
if (config.output === 'minimal') return segment.image('base64://' + base64)
const attrs = {
userId: session.userId,
nickname: session.author?.nickname || session.username,
}
const result = segment('figure')
const lines = [`seed = ${parameters.seed}`]
if (config.output === 'verbose') {
if (!thirdParty()) {
lines.push(`model = ${model}`)
}
lines.push(
`sampler = ${options.sampler}`,
`steps = ${parameters.steps}`,
`scale = ${parameters.scale}`,
)
if (parameters.image) {
lines.push(
`strength = ${parameters.strength}`,
`noise = ${parameters.noise}`,
)
}
}
cleanUp()
return handleError(session, err)
result.children.push(segment('message', attrs, lines.join('\n')))
result.children.push(segment('message', attrs, `prompt = ${prompt}`))
if (config.output === 'verbose') {
result.children.push(segment('message', attrs, `undesired = ${uc}`))
}
result.children.push(segment('message', attrs, segment.image('base64://' + base64)))
return result
}
const messageIds = await session.send(getContent())
if (messageIds.length && config.recallTimeout) {
ctx.setTimeout(() => {
for (const id of messageIds) {
session.bot.deleteMessage(session.channelId, id)
}
}, config.recallTimeout)
}
}
if (!base64.trim()) return session.text('.empty-response')
function getContent() {
if (config.output === 'minimal') return segment.image('base64://' + base64)
const attrs = {
userId: session.userId,
nickname: session.author?.nickname || session.username,
while (iterations.length) {
try {
await iterate()
cleanUp(iterations.pop())
parameters.seed++
} catch (err) {
iterations.forEach(cleanUp)
throw err
}
const result = segment('figure')
const lines = [`seed = ${seed}`]
if (config.output === 'verbose') {
if (!thirdParty()) {
lines.push(`model = ${model}`)
}
lines.push(
`sampler = ${options.sampler}`,
`steps = ${parameters.steps}`,
`scale = ${parameters.scale}`,
)
if (parameters.image) {
lines.push(
`strength = ${parameters.strength}`,
`noise = ${parameters.noise}`,
)
}
}
result.children.push(segment('message', attrs, lines.join('\n')))
result.children.push(segment('message', attrs, `prompt = ${prompt}`))
if (config.output === 'verbose') {
result.children.push(segment('message', attrs, `undesired = ${uc}`))
}
result.children.push(segment('message', attrs, segment.image('base64://' + base64)))
return result
}
const ids = await session.send(getContent())
if (config.recallTimeout) {
ctx.setTimeout(() => {
for (const id of ids) {
session.bot.deleteMessage(session.channelId, id)
}
}, config.recallTimeout)
}
})

View File

@ -19,8 +19,10 @@ commands:
noise: 图片噪声强度
undesired: 排除标签
noTranslator: 禁用自动翻译
iterations: 设置绘制次数
messages:
exceed-max-iteration: 超过最大绘制次数。
expect-prompt: 请输入标签。
expect-image: 请输入图片。
latin-only: 只接受英文输入。