mirror of
https://github.com/koishijs/novelai-bot
synced 2025-01-08 11:17:32 +08:00
feat: unified support for text2img, img2img and enhance
This commit is contained in:
parent
f7ea85fd1a
commit
46cff6ecab
@ -10,8 +10,7 @@
|
||||
- [x] 自定义违禁词表
|
||||
- [x] 发送一段时间后自动撤回
|
||||
- [x] 连接到自建私服
|
||||
- [ ] 图片增强功能
|
||||
- [ ] img2img
|
||||
- [x] img2img · 图片增强功能
|
||||
|
||||
得益于 Koishi 的插件化机制,只需配合其他插件即可实现更多功能:
|
||||
|
||||
@ -162,12 +161,12 @@ console.log(JSON.parse(localStorage.session).auth_token)
|
||||
|
||||
默认情况下是否过滤不良构图。
|
||||
|
||||
### baseTags
|
||||
### basePrompt
|
||||
|
||||
- 类型: `string`
|
||||
- 默认值: `''`
|
||||
- 默认值: `'masterpiece, best quality'`
|
||||
|
||||
默认的附加标签。可以自定义一些常用的标签,例如 `best quality`, `masterpiece` 等。
|
||||
所有请求的附加标签。默认值相当于开启网页版的「Add Quality Tags」功能。
|
||||
|
||||
### forbidden
|
||||
|
||||
|
177
src/index.ts
177
src/index.ts
@ -1,7 +1,6 @@
|
||||
import { Context, Dict, Logger, Quester, Schema, segment, Time } from 'koishi'
|
||||
|
||||
import { Context, Dict, Logger, Quester, Schema, Time, Session, segment } from 'koishi'
|
||||
|
||||
import { Context, Dict, Logger, Quester, Schema, segment, Session, Time } from 'koishi'
|
||||
import { download } from './utils'
|
||||
import getImageSize from 'image-size'
|
||||
|
||||
export const reactive = true
|
||||
export const name = 'novelai'
|
||||
@ -37,7 +36,8 @@ export interface Config {
|
||||
orient?: Orient
|
||||
sampler?: Sampler
|
||||
anatomy?: boolean
|
||||
baseTags?: string
|
||||
allowAnlas?: boolean
|
||||
basePrompt?: string
|
||||
forbidden?: string
|
||||
endpoint?: string
|
||||
requestTimeout?: number
|
||||
@ -51,25 +51,15 @@ export const Config: Schema<Config> = Schema.object({
|
||||
orient: Schema.union(orients).description('默认的图片方向。').default('portrait'),
|
||||
sampler: Schema.union(samplers).description('默认的采样器。').default('k_euler_ancestral'),
|
||||
anatomy: Schema.boolean().default(true).description('是否过滤不合理构图。'),
|
||||
baseTags: Schema.string().description('默认的附加标签。').default(''),
|
||||
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
|
||||
allowAnlas: Schema.boolean().default(true).description('是否允许使用点数。'),
|
||||
basePrompt: Schema.string().description('默认的附加标签。').default('masterpiece, best quality'),
|
||||
forbidden: Schema.string().role('textarea').description('违禁词列表。含有违禁词的请求将被拒绝。').default(''),
|
||||
endpoint: Schema.string().description('API 服务器地址。').default('https://api.novelai.net'),
|
||||
requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute * 0.5),
|
||||
recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0),
|
||||
maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0),
|
||||
})
|
||||
|
||||
function assembleMsgNode(user: {uin: string; name: string}, content: string | string[] | {}) {
|
||||
return {
|
||||
type: 'node',
|
||||
data: {
|
||||
uin: user.uin,
|
||||
name: user.name,
|
||||
content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
function errorHandler(session: Session, err: Error) {
|
||||
if (Quester.isAxiosError(err)) {
|
||||
if (err.response?.status === 429) {
|
||||
@ -105,6 +95,7 @@ export function apply(ctx: Context, config: Config) {
|
||||
const cmd = ctx.command('novelai <prompts:text>')
|
||||
.shortcut('画画', { fuzzy: true })
|
||||
.shortcut('约稿', { fuzzy: true })
|
||||
.option('enhance', '-e')
|
||||
.option('model', '-m <model>', { type: models })
|
||||
.option('orient', '-o <orient>', { type: orients })
|
||||
.option('sampler', '-s <sampler>', { type: samplers })
|
||||
@ -113,6 +104,19 @@ export function apply(ctx: Context, config: Config) {
|
||||
.option('anatomy', '-A', { value: false })
|
||||
.action(async ({ session, options }, input) => {
|
||||
if (!input?.trim()) return session.execute('help novelai')
|
||||
|
||||
let imgUrl: string
|
||||
input = segment.transform(input, {
|
||||
image(attrs) {
|
||||
imgUrl = attrs.url
|
||||
return ''
|
||||
},
|
||||
})
|
||||
|
||||
if (options.enhance && !imgUrl) {
|
||||
return session.text('.expect-image')
|
||||
}
|
||||
|
||||
input = input.toLowerCase().replace(/[,,]/g, ', ').replace(/\s+/g, ' ')
|
||||
if (/[^\s\w"'“”‘’.,:|\[\]\{\}-]/.test(input)) {
|
||||
return session.text('.invalid-input')
|
||||
@ -139,30 +143,54 @@ export function apply(ctx: Context, config: Config) {
|
||||
if (options.anatomy ?? config.anatomy) undesired.push(badAnatomy)
|
||||
const seed = options.seed || Math.round(new Date().getTime() / 1000)
|
||||
session.send(session.text('.waiting'))
|
||||
input += config.baseTags ? ', ' + config.baseTags : ''
|
||||
input += config.basePrompt ? ', ' + config.basePrompt : ''
|
||||
|
||||
const parameters: Dict = {
|
||||
seed,
|
||||
n_samples: 1,
|
||||
sampler: options.sampler,
|
||||
uc: undesired,
|
||||
ucPreset: 0,
|
||||
}
|
||||
|
||||
if (imgUrl) {
|
||||
const image = await download(ctx, imgUrl)
|
||||
const size = getImageSize(image)
|
||||
Object.assign(parameters, {
|
||||
image: Buffer.from(image).toString('base64'),
|
||||
scale: 11,
|
||||
steps: 50,
|
||||
})
|
||||
if (options.enhance) {
|
||||
Object.assign(parameters, {
|
||||
height: size.height * 1.5,
|
||||
width: size.width * 1.5,
|
||||
noise: 0,
|
||||
strength: 0.2,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
Object.assign(parameters, {
|
||||
scale: 12,
|
||||
steps: 28,
|
||||
})
|
||||
}
|
||||
|
||||
if (!options.enhance) {
|
||||
Object.assign(parameters, {
|
||||
height: orient.height,
|
||||
width: orient.width,
|
||||
noise: 0.2,
|
||||
strength: 0.7,
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
const art = await ctx.http.axios(config.endpoint + '/ai/generate-image', {
|
||||
method: 'POST',
|
||||
timeout: config.requestTimeout,
|
||||
headers: headers(config),
|
||||
data: {
|
||||
model,
|
||||
input,
|
||||
parameters: {
|
||||
height: orient.height,
|
||||
width: orient.width,
|
||||
seed,
|
||||
n_samples: 1,
|
||||
noise: 0.2,
|
||||
sampler: options.sampler,
|
||||
scale: 12,
|
||||
steps: 28,
|
||||
strength: 0.7,
|
||||
uc: undesired.join(', '),
|
||||
ucPreset: 1,
|
||||
},
|
||||
},
|
||||
data: { model, input, parameters },
|
||||
}).then(res => {
|
||||
return res.data.substr(27, res.data.length)
|
||||
})
|
||||
@ -173,7 +201,7 @@ export function apply(ctx: Context, config: Config) {
|
||||
}
|
||||
const ids = await session.send(segment('message', { forward: true }, [
|
||||
segment('message', attrs, `seed = ${seed}`),
|
||||
segment('message', attrs, input),
|
||||
segment('message', attrs, `prompt = ${input}`),
|
||||
segment('message', attrs, segment.image('base64://' + art)),
|
||||
]))
|
||||
if (config.recallTimeout) {
|
||||
@ -189,78 +217,11 @@ export function apply(ctx: Context, config: Config) {
|
||||
} finally {
|
||||
states[session.cid]?.delete(id)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
const enhance = ctx.guild().command('novelaiEnhance <img:text>')
|
||||
.shortcut('增强', { fuzzy: true })
|
||||
.option('model', '-m <model>', { type: models })
|
||||
.option('sampler', '-s <sampler>', { type: samplers })
|
||||
.option('undesired', '-u <undesired>', { type: undesiredContents})
|
||||
.before(session => {
|
||||
if (!session.args || segment.parse(session.args[0])[0].type !== 'image') return '需要传入图片'
|
||||
})
|
||||
.action(async ({ session, options, args }, input) => {
|
||||
const id = Math.random().toString(36).slice(2)
|
||||
if (config.maxConcurrency) {
|
||||
states[session.cid] ||= new Set()
|
||||
if (states[session.cid].size >= config.maxConcurrency) {
|
||||
return session.text('.concurrent-jobs')
|
||||
} else {
|
||||
states[session.cid].add(id)
|
||||
}
|
||||
}
|
||||
|
||||
const model = modelMap[options.model]
|
||||
const undesired = undesiredMap[options.undesired]
|
||||
const seed = Math.round(new Date().getTime() / 1000)
|
||||
const imgUrl = segment.parse(args[0])[0].attrs.url
|
||||
const image = await readRemote(imgUrl, {})
|
||||
const dim = getImgSize(image)
|
||||
const b64Img = Buffer.from(image).toString('base64')
|
||||
|
||||
try {
|
||||
const art = await ctx.http.axios('https://api.novelai.net/ai/generate-image', {
|
||||
method: 'POST',
|
||||
timeout: config.requestTimeout,
|
||||
headers: headers(config),
|
||||
data: {
|
||||
model,
|
||||
input: "masterpiece, best quality, girl",
|
||||
parameters: {
|
||||
height: dim.height * 1.5,
|
||||
width: dim.width * 1.5,
|
||||
image: b64Img,
|
||||
seed,
|
||||
n_samples: 1,
|
||||
noise: 0,
|
||||
sampler: options.sampler,
|
||||
scale: 11,
|
||||
steps: 50,
|
||||
strength: 0.2,
|
||||
uc: undesired,
|
||||
ucPreset: 0,
|
||||
},
|
||||
},
|
||||
}).then(res => {
|
||||
return res.data.substr(27, res.data.length)
|
||||
})
|
||||
return segment.image('base64://' + art)
|
||||
} catch (err) {
|
||||
errorHandler(session, err)
|
||||
return session.text('.unknown-error')
|
||||
} finally {
|
||||
states[session.cid]?.delete(id)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
ctx.accept(['model', 'orient', 'sampler'], (config) => {
|
||||
draw._options.model.fallback = config.model
|
||||
draw._options.orient.fallback = config.orient
|
||||
draw._options.sampler.fallback = config.sampler
|
||||
draw._options.undesired.fallback = config.undesiredContents
|
||||
enhance._options.model.fallback = config.model
|
||||
enhance._options.sampler.fallback = config.sampler
|
||||
enhance._options.undesired.fallback = config.undesiredContents
|
||||
cmd._options.model.fallback = config.model
|
||||
cmd._options.orient.fallback = config.orient
|
||||
cmd._options.sampler.fallback = config.sampler
|
||||
}, { immediate: true })
|
||||
}
|
||||
|
36
src/utils.ts
36
src/utils.ts
@ -1,26 +1,18 @@
|
||||
import axios from 'axios'
|
||||
import sizeOf from 'image-size'
|
||||
import { ISizeCalculationResult } from 'image-size/dist/types/interface'
|
||||
import { Context } from 'koishi'
|
||||
|
||||
const MAX_CONTENT_SIZE = 10485760
|
||||
const ALLOW_MIMETYPE = ['jpeg', 'png']
|
||||
const ALLOWED_TYPES = ['jpeg', 'png']
|
||||
|
||||
async function readRemote(url: string, headers: {}): Promise<Buffer> {
|
||||
const head = await axios.head(url, { headers })
|
||||
|
||||
if (parseInt(head.headers['content-length']) > MAX_CONTENT_SIZE) throw 'file too large'
|
||||
else if (
|
||||
ALLOW_MIMETYPE.every(t => { head.headers['content-type'].search(t) === -1 })
|
||||
) throw 'unsupported file type'
|
||||
|
||||
return axios.get(url, {
|
||||
responseType: 'arraybuffer',
|
||||
headers
|
||||
}).then(res => { return res.data })
|
||||
export async function download(ctx: Context, url: string, headers = {}): Promise<Buffer> {
|
||||
const head = await ctx.http.head(url, { headers })
|
||||
|
||||
if (+head.headers['content-length'] > MAX_CONTENT_SIZE) {
|
||||
throw new Error('file too large')
|
||||
}
|
||||
|
||||
if (ALLOWED_TYPES.every(t => head.headers['content-type'].includes(t))) {
|
||||
throw new Error('unsupported file type')
|
||||
}
|
||||
|
||||
return ctx.http.get(url, { responseType: 'arraybuffer', headers })
|
||||
}
|
||||
|
||||
function getImgSize(img: Buffer): ISizeCalculationResult {
|
||||
return sizeOf(img)
|
||||
}
|
||||
|
||||
export { readRemote, getImgSize }
|
||||
|
Loading…
Reference in New Issue
Block a user