feat: support batch generation, fix #198 (#200)

This commit is contained in:
Maiko Sinkyaet Tan 2023-07-01 22:50:52 +08:00 committed by GitHub
parent b6ca19bb93
commit 0e0c3a2ab1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 9 deletions

View File

@ -47,6 +47,18 @@
约稿 -i 10 koishi
```
### 批量生成 (batch)
::: tip
此功能需要通过配置项 [`maxIterations`](./config.md#maxiterations) 手动开启。
:::
如果想要以一组输入批量生成图片,可以使用 `-b, --batch` 参数:
```text
约稿 -b 10 koishi
```
### 输出方式 (output)
此插件提供了三种不同的输出方式:`minimal` 表示只发送图片,`default` 表示发送图片和关键信息,`verbose` 表示发送全部信息。你可以使用 `-o, --output` 手动指定输出方式,或通过配置项修改默认的行为。

View File

@ -104,6 +104,7 @@ export function apply(ctx: Context, config: Config) {
.option('undesired', '-u <undesired>')
.option('noTranslator', '-T', { hidden: () => !ctx.translator || !config.translator })
.option('iterations', '-i <iterations:posint>', { fallback: 1, hidden: () => config.maxIterations <= 1 })
.option('batch', '-b <batch:option>', { fallback: 1, hidden: () => config.maxIterations <= 1 })
.action(async ({ session, options }, input) => {
if (!input?.trim()) return session.execute('help novelai')
@ -115,7 +116,9 @@ export function apply(ctx: Context, config: Config) {
return session.text('.custom-resolution-unsupported')
}
if (options.iterations && options.iterations > config.maxIterations) {
const { batch = 1, iterations = 1 } = options
const total = batch * iterations
if (total > config.maxIterations) {
return session.text('.exceed-max-iteration', [config.maxIterations])
}
@ -185,7 +188,7 @@ export function apply(ctx: Context, config: Config) {
const parameters: Dict = {
seed,
prompt,
n_samples: 1,
n_samples: options.batch,
uc,
// 0: low quality + bad anatomy
// 1: low quality
@ -245,13 +248,13 @@ export function apply(ctx: Context, config: Config) {
}
const getRandomId = () => Math.random().toString(36).slice(2)
const iterations = Array(options.iterations).fill(0).map(getRandomId)
const container = Array(iterations).fill(0).map(getRandomId)
if (config.maxConcurrency) {
const store = tasks[session.cid] ||= new Set()
if (store.size >= config.maxConcurrency) {
return session.text('.concurrent-jobs')
} else {
iterations.forEach((id) => store.add(id))
container.forEach((id) => store.add(id))
}
}
@ -259,7 +262,7 @@ export function apply(ctx: Context, config: Config) {
? session.text('.pending', [globalTasks.size])
: session.text('.waiting'))
iterations.forEach((id) => globalTasks.add(id))
container.forEach((id) => globalTasks.add(id))
const cleanUp = (id: string) => {
tasks[session.cid]?.delete(id)
globalTasks.delete(id)
@ -322,7 +325,7 @@ export function apply(ctx: Context, config: Config) {
karras: options.sampler.includes('_ka'),
hires_fix: options.hiresFix ?? config.hiresFix ?? false,
steps: parameters.steps,
n: 1,
n: parameters.n_samples,
},
nsfw: nsfw !== 'disallow',
trusted_workers: config.trustedWorkers,
@ -451,13 +454,13 @@ export function apply(ctx: Context, config: Config) {
}
}
while (iterations.length) {
while (container.length) {
try {
await iterate()
cleanUp(iterations.pop())
cleanUp(container.pop())
parameters.seed++
} catch (err) {
iterations.forEach(cleanUp)
container.forEach(cleanUp)
throw err
}
}

View File

@ -25,6 +25,7 @@ commands:
undesired: 排除标签
noTranslator: 禁用自动翻译
iterations: 设置绘制次数
batch: 设置绘制批次大小
messages:
exceed-max-iteration: 超过最大绘制次数。