tidy(ui): use object for addRegions graph builder util arg

This commit is contained in:
psychedelicious 2024-11-28 14:26:24 +10:00
parent 5d8dd6e26e
commit c276b60af9
4 changed files with 59 additions and 44 deletions

View File

@ -30,10 +30,25 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
return isEnabled && (hasTextPrompt || hasIPAdapter);
};
type AddRegionsArg = {
manager: CanvasManager;
regions: CanvasRegionalGuidanceState[];
g: Graph;
bbox: Rect;
base: BaseModelType;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
posCondCollect: Invocation<'collect'>;
negCondCollect: Invocation<'collect'> | null;
ipAdapterCollect: Invocation<'collect'>;
};
/**
* Adds regional guidance to the graph
* @param manager The canvas manager
* @param regions Array of regions to add
* @param g The graph to add the layers to
* @param bbox The bounding box
* @param base The base model type
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
@ -43,18 +58,18 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
* @returns A promise that resolves to the regions that were successfully added to the graph
*/
export const addRegions = async (
manager: CanvasManager,
regions: CanvasRegionalGuidanceState[],
g: Graph,
bbox: Rect,
base: BaseModelType,
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null,
posCondCollect: Invocation<'collect'>,
negCondCollect: Invocation<'collect'> | null,
ipAdapterCollect: Invocation<'collect'>
): Promise<AddedRegionResult[]> => {
export const addRegions = async ({
manager,
regions,
g,
bbox,
base,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = base === 'sdxl';
const isFLUX = base === 'flux';

View File

@ -213,31 +213,31 @@ export const buildFLUXGraph = async (
g.deleteNode(controlNetCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base);
const regionsResult = await addRegions(
const regionsResult = await addRegions({
manager,
canvas.regionalGuidance.entities,
regions: canvas.regionalGuidance.entities,
g,
canvas.bbox.rect,
modelConfig.base,
bbox: canvas.bbox.rect,
base: modelConfig.base,
posCond,
null,
negCond: null,
posCondCollect,
null,
ipAdapterCollector
);
negCondCollect: null,
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {

View File

@ -259,31 +259,31 @@ export const buildSD1Graph = async (
g.deleteNode(t2iAdapterCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base);
const regionsResult = await addRegions(
const regionsResult = await addRegions({
manager,
canvas.regionalGuidance.entities,
regions: canvas.regionalGuidance.entities,
g,
canvas.bbox.rect,
modelConfig.base,
bbox: canvas.bbox.rect,
base: modelConfig.base,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollector
);
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {

View File

@ -264,31 +264,31 @@ export const buildSDXLGraph = async (
g.deleteNode(t2iAdapterCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base);
const regionsResult = await addRegions(
const regionsResult = await addRegions({
manager,
canvas.regionalGuidance.entities,
regions: canvas.regionalGuidance.entities,
g,
canvas.bbox.rect,
modelConfig.base,
bbox: canvas.bbox.rect,
base: modelConfig.base,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollector
);
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {