feat: unified support for text2img, img2img and enhance

This commit is contained in:
Shigma 2022-10-09 02:40:16 +08:00
parent f7ea85fd1a
commit 46cff6ecab
No known key found for this signature in database
GPG Key ID: 21C89B0B92907E14
3 changed files with 87 additions and 135 deletions

View File

@ -10,8 +10,7 @@
- [x] 自定义违禁词表
- [x] 发送一段时间后自动撤回
- [x] 连接到自建私服
- [ ] 图片增强功能
- [ ] img2img
- [x] img2img · 图片增强功能
得益于 Koishi 的插件化机制,只需配合其他插件即可实现更多功能:
@ -162,12 +161,12 @@ console.log(JSON.parse(localStorage.session).auth_token)
默认情况下是否过滤不良构图。
### baseTags
### basePrompt
- 类型: `string`
- 默认值: `''`
- 默认值: `'masterpiece, best quality'`
默认的附加标签。可以自定义一些常用的标签,例如 `best quality`, `masterpiece`
所有请求的附加标签。默认值相当于开启网页版的「Add Quality Tags」功能
### forbidden

View File

@ -1,7 +1,6 @@
import { Context, Dict, Logger, Quester, Schema, segment, Time } from 'koishi'
import { Context, Dict, Logger, Quester, Schema, Time, Session, segment } from 'koishi'
import { Context, Dict, Logger, Quester, Schema, segment, Session, Time } from 'koishi'
import { download } from './utils'
import getImageSize from 'image-size'
export const reactive = true
export const name = 'novelai'
@ -37,7 +36,8 @@ export interface Config {
orient?: Orient
sampler?: Sampler
anatomy?: boolean
baseTags?: string
allowAnlas?: boolean
basePrompt?: string
forbidden?: string
endpoint?: string
requestTimeout?: number
@ -51,25 +51,15 @@ export const Config: Schema<Config> = Schema.object({
orient: Schema.union(orients).description('默认的图片方向。').default('portrait'),
sampler: Schema.union(samplers).description('默认的采样器。').default('k_euler_ancestral'),
anatomy: Schema.boolean().default(true).description('是否过滤不合理构图。'),
baseTags: Schema.string().description('默认的附加标签。').default(''),
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
allowAnlas: Schema.boolean().default(true).description('是否允许使用点数。'),
basePrompt: Schema.string().description('默认的附加标签。').default('masterpiece, best quality'),
forbidden: Schema.string().role('textarea').description('违禁词列表。含有违禁词的请求将被拒绝。').default(''),
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute * 0.5),
recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0),
maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0),
})
function assembleMsgNode(user: {uin: string; name: string}, content: string | string[] | {}) {
return {
type: 'node',
data: {
uin: user.uin,
name: user.name,
content,
},
}
}
function errorHandler(session: Session, err: Error) {
if (Quester.isAxiosError(err)) {
if (err.response?.status === 429) {
@ -105,6 +95,7 @@ export function apply(ctx: Context, config: Config) {
const cmd = ctx.command('novelai <prompts:text>')
.shortcut('画画', { fuzzy: true })
.shortcut('约稿', { fuzzy: true })
.option('enhance', '-e')
.option('model', '-m <model>', { type: models })
.option('orient', '-o <orient>', { type: orients })
.option('sampler', '-s <sampler>', { type: samplers })
@ -113,6 +104,19 @@ export function apply(ctx: Context, config: Config) {
.option('anatomy', '-A', { value: false })
.action(async ({ session, options }, input) => {
if (!input?.trim()) return session.execute('help novelai')
let imgUrl: string
input = segment.transform(input, {
image(attrs) {
imgUrl = attrs.url
return ''
},
})
if (options.enhance && !imgUrl) {
return session.text('.expect-image')
}
input = input.toLowerCase().replace(/[,]/g, ', ').replace(/\s+/g, ' ')
if (/[^\s\w"'“”‘’.,:|\[\]\{\}-]/.test(input)) {
return session.text('.invalid-input')
@ -139,30 +143,54 @@ export function apply(ctx: Context, config: Config) {
if (options.anatomy ?? config.anatomy) undesired.push(badAnatomy)
const seed = options.seed || Math.round(new Date().getTime() / 1000)
session.send(session.text('.waiting'))
input += config.baseTags ? ', ' + config.baseTags : ''
input += config.basePrompt ? ', ' + config.basePrompt : ''
const parameters: Dict = {
seed,
n_samples: 1,
sampler: options.sampler,
uc: undesired,
ucPreset: 0,
}
if (imgUrl) {
const image = await download(ctx, imgUrl)
const size = getImageSize(image)
Object.assign(parameters, {
image: Buffer.from(image).toString('base64'),
scale: 11,
steps: 50,
})
if (options.enhance) {
Object.assign(parameters, {
height: size.height * 1.5,
width: size.width * 1.5,
noise: 0,
strength: 0.2,
})
}
} else {
Object.assign(parameters, {
scale: 12,
steps: 28,
})
}
if (!options.enhance) {
Object.assign(parameters, {
height: orient.height,
width: orient.width,
noise: 0.2,
strength: 0.7,
})
}
try {
const art = await ctx.http.axios(config.endpoint + '/ai/generate-image', {
method: 'POST',
timeout: config.requestTimeout,
headers: headers(config),
data: {
model,
input,
parameters: {
height: orient.height,
width: orient.width,
seed,
n_samples: 1,
noise: 0.2,
sampler: options.sampler,
scale: 12,
steps: 28,
strength: 0.7,
uc: undesired.join(', '),
ucPreset: 1,
},
},
data: { model, input, parameters },
}).then(res => {
return res.data.substr(27, res.data.length)
})
@ -173,7 +201,7 @@ export function apply(ctx: Context, config: Config) {
}
const ids = await session.send(segment('message', { forward: true }, [
segment('message', attrs, `seed = ${seed}`),
segment('message', attrs, input),
segment('message', attrs, `prompt = ${input}`),
segment('message', attrs, segment.image('base64://' + art)),
]))
if (config.recallTimeout) {
@ -189,78 +217,11 @@ export function apply(ctx: Context, config: Config) {
} finally {
states[session.cid]?.delete(id)
}
}
)
const enhance = ctx.guild().command('novelaiEnhance <img:text>')
.shortcut('增强', { fuzzy: true })
.option('model', '-m <model>', { type: models })
.option('sampler', '-s <sampler>', { type: samplers })
.option('undesired', '-u <undesired>', { type: undesiredContents})
.before(session => {
if (!session.args || segment.parse(session.args[0])[0].type !== 'image') return '需要传入图片'
})
.action(async ({ session, options, args }, input) => {
const id = Math.random().toString(36).slice(2)
if (config.maxConcurrency) {
states[session.cid] ||= new Set()
if (states[session.cid].size >= config.maxConcurrency) {
return session.text('.concurrent-jobs')
} else {
states[session.cid].add(id)
}
}
const model = modelMap[options.model]
const undesired = undesiredMap[options.undesired]
const seed = Math.round(new Date().getTime() / 1000)
const imgUrl = segment.parse(args[0])[0].attrs.url
const image = await readRemote(imgUrl, {})
const dim = getImgSize(image)
const b64Img = Buffer.from(image).toString('base64')
try {
const art = await ctx.http.axios('https://api.novelai.net/ai/generate-image', {
method: 'POST',
timeout: config.requestTimeout,
headers: headers(config),
data: {
model,
input: "masterpiece, best quality, girl",
parameters: {
height: dim.height * 1.5,
width: dim.width * 1.5,
image: b64Img,
seed,
n_samples: 1,
noise: 0,
sampler: options.sampler,
scale: 11,
steps: 50,
strength: 0.2,
uc: undesired,
ucPreset: 0,
},
},
}).then(res => {
return res.data.substr(27, res.data.length)
})
return segment.image('base64://' + art)
} catch (err) {
errorHandler(session, err)
return session.text('.unknown-error')
} finally {
states[session.cid]?.delete(id)
}
})
})
ctx.accept(['model', 'orient', 'sampler'], (config) => {
draw._options.model.fallback = config.model
draw._options.orient.fallback = config.orient
draw._options.sampler.fallback = config.sampler
draw._options.undesired.fallback = config.undesiredContents
enhance._options.model.fallback = config.model
enhance._options.sampler.fallback = config.sampler
enhance._options.undesired.fallback = config.undesiredContents
cmd._options.model.fallback = config.model
cmd._options.orient.fallback = config.orient
cmd._options.sampler.fallback = config.sampler
}, { immediate: true })
}

View File

@ -1,26 +1,18 @@
import axios from 'axios'
import sizeOf from 'image-size'
import { ISizeCalculationResult } from 'image-size/dist/types/interface'
import { Context } from 'koishi'
const MAX_CONTENT_SIZE = 10485760
const ALLOW_MIMETYPE = ['jpeg', 'png']
const ALLOWED_TYPES = ['jpeg', 'png']
async function readRemote(url: string, headers: {}): Promise<Buffer> {
const head = await axios.head(url, { headers })
if (parseInt(head.headers['content-length']) > MAX_CONTENT_SIZE) throw 'file too large'
else if (
ALLOW_MIMETYPE.every(t => { head.headers['content-type'].search(t) === -1 })
) throw 'unsupported file type'
return axios.get(url, {
responseType: 'arraybuffer',
headers
}).then(res => { return res.data })
export async function download(ctx: Context, url: string, headers = {}): Promise<Buffer> {
const head = await ctx.http.head(url, { headers })
if (+head.headers['content-length'] > MAX_CONTENT_SIZE) {
throw new Error('file too large')
}
if (ALLOWED_TYPES.every(t => head.headers['content-type'].includes(t))) {
throw new Error('unsupported file type')
}
return ctx.http.get(url, { responseType: 'arraybuffer', headers })
}
function getImgSize(img: Buffer): ISizeCalculationResult {
return sizeOf(img)
}
export { readRemote, getImgSize }