mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-09 04:18:46 +08:00
tidy(ui): use object for addRegions graph builder util arg
This commit is contained in:
parent
5d8dd6e26e
commit
c276b60af9
@ -30,10 +30,25 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
|
||||
return isEnabled && (hasTextPrompt || hasIPAdapter);
|
||||
};
|
||||
|
||||
type AddRegionsArg = {
|
||||
manager: CanvasManager;
|
||||
regions: CanvasRegionalGuidanceState[];
|
||||
g: Graph;
|
||||
bbox: Rect;
|
||||
base: BaseModelType;
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
|
||||
posCondCollect: Invocation<'collect'>;
|
||||
negCondCollect: Invocation<'collect'> | null;
|
||||
ipAdapterCollect: Invocation<'collect'>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Adds regional guidance to the graph
|
||||
* @param manager The canvas manager
|
||||
* @param regions Array of regions to add
|
||||
* @param g The graph to add the layers to
|
||||
* @param bbox The bounding box
|
||||
* @param base The base model type
|
||||
* @param posCond The positive conditioning node
|
||||
* @param negCond The negative conditioning node
|
||||
@ -43,18 +58,18 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
|
||||
* @returns A promise that resolves to the regions that were successfully added to the graph
|
||||
*/
|
||||
|
||||
export const addRegions = async (
|
||||
manager: CanvasManager,
|
||||
regions: CanvasRegionalGuidanceState[],
|
||||
g: Graph,
|
||||
bbox: Rect,
|
||||
base: BaseModelType,
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
|
||||
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null,
|
||||
posCondCollect: Invocation<'collect'>,
|
||||
negCondCollect: Invocation<'collect'> | null,
|
||||
ipAdapterCollect: Invocation<'collect'>
|
||||
): Promise<AddedRegionResult[]> => {
|
||||
export const addRegions = async ({
|
||||
manager,
|
||||
regions,
|
||||
g,
|
||||
bbox,
|
||||
base,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollect,
|
||||
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
|
||||
const isSDXL = base === 'sdxl';
|
||||
const isFLUX = base === 'flux';
|
||||
|
||||
|
@ -213,31 +213,31 @@ export const buildFLUXGraph = async (
|
||||
g.deleteNode(controlNetCollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollector = g.addNode({
|
||||
const ipAdapterCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base);
|
||||
|
||||
const regionsResult = await addRegions(
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
canvas.regionalGuidance.entities,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
modelConfig.base,
|
||||
bbox: canvas.bbox.rect,
|
||||
base: modelConfig.base,
|
||||
posCond,
|
||||
null,
|
||||
negCond: null,
|
||||
posCondCollect,
|
||||
null,
|
||||
ipAdapterCollector
|
||||
);
|
||||
negCondCollect: null,
|
||||
ipAdapterCollect,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
|
||||
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||
} else {
|
||||
g.deleteNode(ipAdapterCollector.id);
|
||||
g.deleteNode(ipAdapterCollect.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -259,31 +259,31 @@ export const buildSD1Graph = async (
|
||||
g.deleteNode(t2iAdapterCollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollector = g.addNode({
|
||||
const ipAdapterCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base);
|
||||
|
||||
const regionsResult = await addRegions(
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
canvas.regionalGuidance.entities,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
modelConfig.base,
|
||||
bbox: canvas.bbox.rect,
|
||||
base: modelConfig.base,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollector
|
||||
);
|
||||
ipAdapterCollect,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
|
||||
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||
} else {
|
||||
g.deleteNode(ipAdapterCollector.id);
|
||||
g.deleteNode(ipAdapterCollect.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -264,31 +264,31 @@ export const buildSDXLGraph = async (
|
||||
g.deleteNode(t2iAdapterCollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollector = g.addNode({
|
||||
const ipAdapterCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base);
|
||||
|
||||
const regionsResult = await addRegions(
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
canvas.regionalGuidance.entities,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
canvas.bbox.rect,
|
||||
modelConfig.base,
|
||||
bbox: canvas.bbox.rect,
|
||||
base: modelConfig.base,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollector
|
||||
);
|
||||
ipAdapterCollect,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
|
||||
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||
} else {
|
||||
g.deleteNode(ipAdapterCollector.id);
|
||||
g.deleteNode(ipAdapterCollect.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
Loading…
Reference in New Issue
Block a user