Merge branch 'main' into ryan/flux-trajectory-guidance

This commit is contained in:
Ryan Dick 2024-09-20 22:29:34 +00:00
commit 183a67cb1e
45 changed files with 921 additions and 418 deletions

View File

@ -61,12 +61,14 @@ class Classification(str, Enum, metaclass=MetaEnum):
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
- `Deprecated`: The invocation is deprecated and may be removed in a future version.
- `Internal`: The invocation is not intended for use by end-users. It may be changed or removed at any time, but is exposed for users to play with.
"""
Stable = "stable"
Beta = "beta"
Prototype = "prototype"
Deprecated = "deprecated"
Internal = "internal"
class UIConfigBase(BaseModel):

View File

@ -10,7 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
@ -18,7 +17,6 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
WithBoard,
WithMetadata,
)
@ -1015,19 +1013,13 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto)
@invocation_output("canvas_v2_mask_and_crop_output")
class CanvasV2MaskAndCropOutput(ImageOutput):
offset_x: int = OutputField(description="The x offset of the image, after cropping")
offset_y: int = OutputField(description="The y offset of the image, after cropping")
@invocation(
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
classification=Classification.Internal,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
@ -1049,7 +1041,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
if self.source_image:
@ -1062,13 +1054,4 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
generated_image.putalpha(mask)
image_dto = context.images.save(image=generated_image)
# bbox = image.getbbox()
# image = image.crop(bbox)
return CanvasV2MaskAndCropOutput(
image=ImageField(image_name=image_dto.image_name),
offset_x=0,
offset_y=0,
width=image_dto.width,
height=image_dto.height,
)
return ImageOutput.build(image_dto)

Binary file not shown.

After

Width:  |  Height:  |  Size: 175 KiB

View File

@ -939,6 +939,7 @@
"zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
"internalDesc": "This invocation is used internally by Invoke. It may have breaking changes during app updates and may be removed at any time.",
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default",
@ -993,6 +994,7 @@
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
"canvasManagerNotLoaded": "Canvas Manager not loaded",
"fluxModelIncompatibleDimensions": "FLUX requires image dimension to be multiples of 16",
"canvasIsFiltering": "Canvas is filtering",
"canvasIsTransforming": "Canvas is transforming",
"canvasIsRasterizing": "Canvas is rasterizing",
@ -1541,6 +1543,12 @@
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
]
},
"fluxDevLicense": {
"heading": "Non-Commercial License",
"paragraphs": [
"FLUX.1 [dev] models are licensed under the FLUX [dev] non-commercial license. To use this model type for commercial purposes in Invoke, visit our website to learn more."
]
},
"optimizedDenoising": {
"heading": "Optimized Inpainting",
"paragraphs": [
@ -2081,6 +2089,10 @@
},
"showSendingToAlerts": "Alert When Sending to Different View"
},
"newUserExperience": {
"toGetStarted": "To get started, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"gettingStartedSeries": "Want more guidance? Check out our <LinkComponent>Getting Started Series</LinkComponent> for tips on unlocking the full potential of the Invoke Studio."
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"canvasV2Announcement": {

View File

@ -4,7 +4,7 @@ import { useImageViewer } from 'features/gallery/components/ImageViewer/useImage
import { $isMenuOpen } from 'features/stylePresets/store/isMenuOpen';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useWorkflowLibraryModal } from 'features/workflowLibrary/store/isWorkflowLibraryModalOpen';
import { useCallback } from 'react';
import { useCallback, useState } from 'react';
export type StudioDestination =
| 'generation'
@ -16,22 +16,26 @@ export type StudioDestination =
export const useHandleStudioDestination = () => {
const dispatch = useAppDispatch();
const imageViewer = useImageViewer();
const { open: imageViewerOpen, close: imageViewerClose } = useImageViewer();
const [initialized, setInitialized] = useState(false);
const workflowLibraryModal = useWorkflowLibraryModal();
const handleStudioDestination = useCallback(
(destination: StudioDestination) => {
if (initialized) {
return;
}
switch (destination) {
case 'generation':
dispatch(setActiveTab('canvas'));
dispatch(settingsSendToCanvasChanged(false));
imageViewer.open();
imageViewerOpen();
break;
case 'canvas':
dispatch(setActiveTab('canvas'));
dispatch(settingsSendToCanvasChanged(true));
imageViewer.close();
imageViewerClose();
break;
case 'workflows':
dispatch(setActiveTab('workflows'));
@ -41,7 +45,7 @@ export const useHandleStudioDestination = () => {
break;
case 'viewAllWorkflows':
dispatch(setActiveTab('workflows'));
workflowLibraryModal.setFalse();
workflowLibraryModal.setTrue();
break;
case 'viewAllStylePresets':
dispatch(setActiveTab('canvas'));
@ -51,8 +55,9 @@ export const useHandleStudioDestination = () => {
dispatch(setActiveTab('canvas'));
break;
}
setInitialized(true);
},
[dispatch, imageViewer, workflowLibraryModal]
[dispatch, imageViewerOpen, imageViewerClose, workflowLibraryModal, initialized]
);
return handleStudioDestination;

View File

@ -29,6 +29,7 @@ import { OPEN_DELAY, POPOVER_DATA, POPPER_MODIFIERS } from './constants';
type Props = {
feature: Feature;
inPortal?: boolean;
hideDisable?: boolean;
children: ReactElement;
};
@ -37,48 +38,51 @@ const selectShouldEnableInformationalPopovers = createSelector(
(system) => system.shouldEnableInformationalPopovers
);
export const InformationalPopover = memo(({ feature, children, inPortal = true, ...rest }: Props) => {
const shouldEnableInformationalPopovers = useAppSelector(selectShouldEnableInformationalPopovers);
export const InformationalPopover = memo(
({ feature, children, inPortal = true, hideDisable = false, ...rest }: Props) => {
const shouldEnableInformationalPopovers = useAppSelector(selectShouldEnableInformationalPopovers);
const data = useMemo(() => POPOVER_DATA[feature], [feature]);
const data = useMemo(() => POPOVER_DATA[feature], [feature]);
const popoverProps = useMemo(() => merge(omit(data, ['image', 'href', 'buttonLabel']), rest), [data, rest]);
const popoverProps = useMemo(() => merge(omit(data, ['image', 'href', 'buttonLabel']), rest), [data, rest]);
if (!shouldEnableInformationalPopovers) {
return children;
if (!hideDisable && !shouldEnableInformationalPopovers) {
return children;
}
return (
<Popover
isLazy
closeOnBlur={false}
trigger="hover"
variant="informational"
openDelay={OPEN_DELAY}
modifiers={POPPER_MODIFIERS}
placement="top"
{...popoverProps}
>
<PopoverTrigger>{children}</PopoverTrigger>
{inPortal ? (
<Portal>
<Content data={data} feature={feature} hideDisable={hideDisable} />
</Portal>
) : (
<Content data={data} feature={feature} hideDisable={hideDisable} />
)}
</Popover>
);
}
return (
<Popover
isLazy
closeOnBlur={false}
trigger="hover"
variant="informational"
openDelay={OPEN_DELAY}
modifiers={POPPER_MODIFIERS}
placement="top"
{...popoverProps}
>
<PopoverTrigger>{children}</PopoverTrigger>
{inPortal ? (
<Portal>
<Content data={data} feature={feature} />
</Portal>
) : (
<Content data={data} feature={feature} />
)}
</Popover>
);
});
);
InformationalPopover.displayName = 'InformationalPopover';
type ContentProps = {
data?: PopoverData;
feature: Feature;
hideDisable: boolean;
};
const Content = ({ data, feature }: ContentProps) => {
const Content = ({ data, feature, hideDisable }: ContentProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const heading = useMemo<string | undefined>(() => t(`popovers.${feature}.heading`), [feature, t]);
@ -120,14 +124,7 @@ const Content = ({ data, feature }: ContentProps) => {
)}
{data?.image && (
<>
<Image
objectFit="contain"
maxW="60%"
maxH="60%"
backgroundColor="white"
src={data.image}
alt="Optional Image"
/>
<Image objectFit="contain" backgroundColor="white" src={data.image} alt="Optional Image" />
<Divider />
</>
)}
@ -137,9 +134,11 @@ const Content = ({ data, feature }: ContentProps) => {
<Divider />
<Flex alignItems="center" justifyContent="space-between" w="full">
<Button onClick={onClickDontShowMeThese} variant="link" size="sm">
{t('common.dontShowMeThese')}
</Button>
{!hideDisable && (
<Button onClick={onClickDontShowMeThese} variant="link" size="sm">
{t('common.dontShowMeThese')}
</Button>
)}
<Spacer />
{data?.href && (
<Button onClick={onClickLearnMore} leftIcon={<PiArrowSquareOutBold />} variant="link" size="sm">

View File

@ -1,4 +1,5 @@
import type { PopoverProps } from '@invoke-ai/ui-library';
import commercialLicenseBg from 'public/assets/images/commercial-license-bg.png';
export type Feature =
| 'clipSkip'
@ -59,7 +60,8 @@ export type Feature =
| 'scale'
| 'creativity'
| 'structure'
| 'optimizedDenoising';
| 'optimizedDenoising'
| 'fluxDevLicense';
export type PopoverData = PopoverProps & {
image?: string;
@ -187,6 +189,10 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
seamlessTilingYAxis: {
href: 'https://support.invoke.ai/support/solutions/articles/151000178161-advanced-settings',
},
fluxDevLicense: {
href: 'https://www.invoke.com/get-a-commercial-license-for-flux',
image: commercialLicenseBg,
},
} as const;
export const OPEN_DELAY = 1000; // in milliseconds

View File

@ -0,0 +1,13 @@
import type { IconProps } from '@invoke-ai/ui-library';
import { Icon } from '@invoke-ai/ui-library';
import { memo } from 'react';
export const InvokeLogoIcon = memo((props: IconProps) => {
return (
<Icon boxSize={8} opacity={1} stroke="base.500" viewBox="0 0 66 66" fill="none" {...props}>
<path d="M43.9137 16H63.1211V3H3.12109V16H22.3285L43.9137 50H63.1211V63H3.12109V50H22.3285" strokeWidth="5" />
</Icon>
);
});
InvokeLogoIcon.displayName = 'InvokeLogoIcon';

View File

@ -157,6 +157,9 @@ const createSelector = (
if (!params.fluxVAE) {
reasons.push({ content: i18n.t('parameters.invoke.noFLUXVAEModelSelected') });
}
if (bbox.rect.width % 16 !== 0 || bbox.rect.height % 16 !== 0) {
reasons.push({ content: i18n.t('parameters.invoke.fluxModelIncompatibleDimensions') });
}
}
canvas.controlLayers.entities

View File

@ -106,7 +106,6 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
variant="ghost"
leftIcon={<PiXBold />}
onClick={adapter.filterer.cancel}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.cancel')}
>
{t('controlLayers.filter.cancel')}

View File

@ -14,7 +14,7 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import type { UploadOptions } from 'services/api/endpoints/images';
import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';
@ -210,7 +210,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
imageDTO = await getImageDTO(cachedImageName);
imageDTO = await getImageDTOSafe(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, imageName: cachedImageName, imageDTO }, 'Using cached composite raster layer image');
return imageDTO;
@ -374,7 +374,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
imageDTO = await getImageDTO(cachedImageName);
imageDTO = await getImageDTOSafe(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached composite inpaint mask image');
return imageDTO;

View File

@ -1,5 +1,4 @@
import type { SerializableObject } from 'common/types';
import { withResultAsync } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@ -13,9 +12,9 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import { serializeError } from 'serialize-error';
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
import { type BatchConfig, type ImageDTO, isControlNetOrT2IAdapterModelConfig, type S } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type CanvasEntityFiltererConfig = {
@ -38,6 +37,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
subscriptions = new Set<() => void>();
config: CanvasEntityFiltererConfig = DEFAULT_CONFIG;
/**
* The AbortController used to cancel the filter processing.
*/
abortController: AbortController | null = null;
$isFiltering = atom<boolean>(false);
$hasProcessed = atom<boolean>(false);
$isProcessing = atom<boolean>(false);
@ -100,63 +104,82 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
processImmediate = async () => {
const config = this.$filterConfig.get();
const isValid = IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
const filterData = IMAGE_FILTERS[config.type];
// Cannot get TS to be happy with `config`, thinks it should be `never`... eh...
const isValid = filterData.validateConfig?.(config as never) ?? true;
if (!isValid) {
this.log.error({ config }, 'Invalid filter config');
return;
}
this.log.trace({ config }, 'Previewing filter');
this.log.trace({ config }, 'Processing filter');
const rect = this.parent.transformer.getRelativeRect();
const imageDTO = await this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } });
const nodeId = getPrefixedId('filter_node');
const batch = this.buildBatchConfig(imageDTO, config, nodeId);
// Listen for the filter processing completion event
const completedListener = async (event: S['InvocationCompleteEvent']) => {
if (event.origin !== this.id || event.invocation_source_id !== nodeId) {
return;
}
this.manager.socket.off('invocation_complete', completedListener);
this.manager.socket.off('invocation_error', errorListener);
this.log.trace({ event } as SerializableObject, 'Handling filter processing completion');
const { result } = event;
assert(result.type === 'image_output', `Processor did not return an image output, got: ${result}`);
const imageDTO = await getImageDTO(result.image.image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
this.imageState = imageDTOToImageObject(imageDTO);
await this.parent.bufferRenderer.setBuffer(this.imageState, true);
const rasterizeResult = await withResultAsync(() =>
this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } })
);
if (rasterizeResult.isErr()) {
this.log.error({ error: serializeError(rasterizeResult.error) }, 'Error rasterizing entity');
this.$isProcessing.set(false);
this.$hasProcessed.set(true);
};
const errorListener = (event: S['InvocationErrorEvent']) => {
if (event.origin !== this.id || event.invocation_source_id !== nodeId) {
return;
}
this.manager.socket.off('invocation_complete', completedListener);
this.manager.socket.off('invocation_error', errorListener);
this.log.error({ event } as SerializableObject, 'Error processing filter');
this.$isProcessing.set(false);
};
this.manager.socket.on('invocation_complete', completedListener);
this.manager.socket.on('invocation_error', errorListener);
this.log.trace({ batch } as SerializableObject, 'Enqueuing filter batch');
return;
}
this.$isProcessing.set(true);
const req = this.manager.stateApi.enqueueBatch(batch);
const result = await withResultAsync(req.unwrap);
if (result.isErr()) {
const imageDTO = rasterizeResult.value;
// Cannot get TS to be happy with `config`, thinks it should be `never`... eh...
const buildGraphResult = withResult(() => filterData.buildGraph(imageDTO, config as never));
if (buildGraphResult.isErr()) {
this.log.error({ error: serializeError(buildGraphResult.error) }, 'Error building filter graph');
this.$isProcessing.set(false);
return;
}
req.reset();
const controller = new AbortController();
this.abortController = controller;
const { graph, outputNodeId } = buildGraphResult.value;
const filterResult = await withResultAsync(() =>
this.manager.stateApi.runGraphAndReturnImageOutput({
graph,
outputNodeId,
// The filter graph should always be prepended to the queue so it's processed ASAP.
prepend: true,
/**
* The filter node may need to download a large model. Currently, the models required by the filter nodes are
* downloaded just-in-time, as required by the filter. If we use a timeout here, we might get into a catch-22
* where the filter node is waiting for the model to download, but the download gets canceled if the filter
* node times out.
*
* (I suspect the model download will actually _not_ be canceled if the graph is canceled, but let's not chance it!)
*
* TODO(psyche): Figure out a better way to handle this. Probably need to download the models ahead of time.
*/
// timeout: 5000,
/**
* The filter node should be able to cancel the request if it's taking too long. This will cancel the graph's
* queue item and clear any event listeners on the request.
*/
signal: controller.signal,
})
);
if (filterResult.isErr()) {
this.log.error({ error: serializeError(filterResult.error) }, 'Error processing filter');
this.$isProcessing.set(false);
this.abortController = null;
return;
}
this.log.trace({ imageDTO: filterResult.value }, 'Filter processed');
this.imageState = imageDTOToImageObject(filterResult.value);
await this.parent.bufferRenderer.setBuffer(this.imageState, true);
this.$isProcessing.set(false);
this.$hasProcessed.set(true);
this.abortController = null;
};
process = debounce(this.processImmediate, this.config.processDebounceMs);
@ -188,6 +211,8 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
reset = () => {
this.log.trace('Resetting filter');
this.abortController?.abort();
this.abortController = null;
this.parent.bufferRenderer.clearBuffer();
this.parent.transformer.updatePosition();
this.parent.renderer.syncCache(true);
@ -205,31 +230,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.manager.stateApi.$filteringAdapter.set(null);
};
buildBatchConfig = (imageDTO: ImageDTO, config: FilterConfig, id: string): BatchConfig => {
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const node = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never);
node.id = id;
const batch: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[node.id]: {
...node,
// filtered images are always intermediate - do not save to gallery
is_intermediate: true,
},
},
edges: [],
},
origin: this.id,
runs: 1,
},
};
return batch;
};
repr = () => {
return {
id: this.id,

View File

@ -1,5 +1,6 @@
import { $authToken } from 'app/store/nanostores/authToken';
import { rgbColorToString } from 'common/util/colorCodeTransformers';
import { withResult } from 'common/util/result';
import { SyncableMap } from 'common/util/SyncableMap/SyncableMap';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@ -27,7 +28,7 @@ import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
@ -356,14 +357,25 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
};
/**
* Rasterizes the parent entity. If the entity has a rasterization cache for the given rect, the cached image is
* returned. Otherwise, the entity is rasterized and the image is uploaded to the server.
* Rasterizes the parent entity, returning a promise that resolves to the image DTO.
*
* If the entity has a rasterization cache for the given rect, the cached image is returned. Otherwise, the entity is
* rasterized and the image is uploaded to the server.
*
* The rasterization cache is reset when the entity's state changes. The buffer object is not considered part of the
* entity state for this purpose as it is a temporary object.
*
* @param rect The rect to rasterize. If omitted, the entity's full rect will be used.
* @returns A promise that resolves to the rasterized image DTO.
* If rasterization fails for any reason, the promise will reject.
*
* @param options The rasterization options.
* @param options.rect The region of the entity to rasterize.
* @param options.replaceObjects Whether to replace the entity's objects with the rasterized image. If you just want
* the entity's image, omit or set this to false.
* @param options.attrs The Konva node attributes to apply to the rasterized image group. For example, you might want
* to disable filters or set the opacity to the rasterized image.
* @param options.bg Draws the entity on a canvas with the given background color. If omitted, the entity is drawn on
* a transparent canvas.
* @returns A promise that resolves to the rasterized image DTO or rejects if rasterization fails.
*/
rasterize = async (options: {
rect: Rect;
@ -383,7 +395,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
imageDTO = await getImageDTO(cachedImageName);
imageDTO = await getImageDTOSafe(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached rasterized image');
return imageDTO;
@ -423,26 +435,38 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
if (this.parent.transformer.$isPendingRectCalculation.get()) {
return;
}
const pixelRect = this.parent.transformer.$pixelRect.get();
if (pixelRect.width === 0 || pixelRect.height === 0) {
return;
}
try {
// TODO(psyche): This is an internal Konva method, so it may break in the future. Can we make this API public?
const canvas = this.konva.objectGroup._getCachedSceneCanvas()._canvas as HTMLCanvasElement | undefined | null;
if (canvas) {
const nodeRect = this.parent.transformer.$nodeRect.get();
const rect = {
x: pixelRect.x - nodeRect.x,
y: pixelRect.y - nodeRect.y,
width: pixelRect.width,
height: pixelRect.height,
};
this.$canvasCache.set({ rect, canvas });
}
} catch (error) {
/**
* TODO(psyche): This is an internal Konva method, so it may break in the future. Can we make this API public?
*
* This method's API is unknown. It has been experimentally determined that it may throw, so we need to handle
* errors.
*/
const getCacheCanvasResult = withResult(
() => this.konva.objectGroup._getCachedSceneCanvas()._canvas as HTMLCanvasElement | undefined | null
);
if (getCacheCanvasResult.isErr()) {
// We are using an internal Konva method, so we need to catch any errors that may occur.
this.log.warn({ error: serializeError(error) }, 'Failed to update preview canvas');
this.log.warn({ error: serializeError(getCacheCanvasResult.error) }, 'Failed to update preview canvas');
return;
}
const canvas = getCacheCanvasResult.value;
if (canvas) {
const nodeRect = this.parent.transformer.$nodeRect.get();
const rect = {
x: pixelRect.x - nodeRect.x,
y: pixelRect.y - nodeRect.y,
width: pixelRect.width,
height: pixelRect.height,
};
this.$canvasCache.set({ rect, canvas });
}
}, 300);

View File

@ -1,3 +1,4 @@
import { withResultAsync } from 'common/util/result';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@ -15,6 +16,7 @@ import type { GroupConfig } from 'konva/lib/Group';
import { debounce, get } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { assert } from 'tsafe';
type CanvasEntityTransformerConfig = {
@ -575,7 +577,12 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.log.debug('Applying transform');
this.$isProcessing.set(true);
const rect = this.getRelativeRect();
await this.parent.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } });
const rasterizeResult = await withResultAsync(() =>
this.parent.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } })
);
if (rasterizeResult.isErr()) {
this.log.error({ error: serializeError(rasterizeResult.error) }, 'Failed to rasterize entity');
}
this.requestRectCalculation();
this.stopTransform();
};

View File

@ -11,7 +11,7 @@ import type { CanvasImageState } from 'features/controlLayers/store/types';
import { t } from 'i18next';
import Konva from 'konva';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import { getImageDTOSafe } from 'services/api/endpoints/images';
export class CanvasObjectImage extends CanvasModuleBase {
readonly type = 'object_image';
@ -100,7 +100,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
this.konva.placeholder.text.text(t('common.loadingImage', 'Loading Image'));
}
const imageDTO = await getImageDTO(imageName);
const imageDTO = await getImageDTOSafe(imageName);
if (imageDTO === null) {
this.onFailedToLoadImage();
return;

View File

@ -2,6 +2,7 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store';
import { withResultAsync } from 'common/util/result';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@ -38,10 +39,13 @@ import type {
RgbaColor,
} from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
import { QueueError } from 'services/events/errors';
import { assert } from 'tsafe';
import type { CanvasEntityAdapter } from './CanvasEntity/types';
@ -187,14 +191,200 @@ export class CanvasStateApiModule extends CanvasModuleBase {
};
/**
* Enqueues a batch, pushing state to redux.
* Run a graph and return an image output. The specified output node must return an image output, else the promise
* will reject with an error.
*
* @param arg The arguments for the function.
* @param arg.graph The graph to execute.
* @param arg.outputNodeId The id of the node whose output will be retrieved.
* @param arg.destination The destination to assign to the batch. If omitted, the destination is not set.
* @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue.
* @param arg.timeout The timeout for the batch. If omitted, there is no timeout.
* @param arg.signal An optional signal to cancel the operation. If omitted, the operation cannot be canceled!
*
* @returns A promise that resolves to the image output or rejects with an error.
*
* @example
*
* ```ts
* const graph = new Graph();
* const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } });
* const controller = new AbortController();
* const imageDTO = await this.manager.stateApi.runGraphAndReturnImageOutput({
* graph,
* outputNodeId: outputNode.id,
* prepend: true,
* signal: controller.signal,
* });
* // To cancel the operation:
* controller.abort();
* ```
*/
enqueueBatch = (batch: BatchConfig) => {
return this.store.dispatch(
runGraphAndReturnImageOutput = async (arg: {
graph: Graph;
outputNodeId: string;
destination?: string;
prepend?: boolean;
timeout?: number;
signal?: AbortSignal;
}): Promise<ImageDTO> => {
const { graph, outputNodeId, destination, prepend, timeout, signal } = arg;
/**
* We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
* race condition:
* - The queue item id is not available until the graph is enqueued
* - The graph may complete before we can set up the listeners to handle the completion event
*
* The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued,
* so we will use that to filter events.
*/
const origin = getPrefixedId(graph.id);
const batch: BatchConfig = {
prepend,
batch: {
graph: graph.getGraph(),
origin,
destination,
runs: 1,
},
};
/**
* If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout
* if the graph completes or errors before the timeout.
*/
let timeoutId: number | null = null;
const _clearTimeout = () => {
if (timeoutId !== null) {
window.clearTimeout(timeoutId);
timeoutId = null;
}
};
/**
* First, enqueue the graph - we need the `batch_id` to cancel the graph. But to get the `batch_id`, we need to
* `await` the request. You might be tempted to `await` the request inside the result promise, but we should not
* `await` inside a promise executor.
*
* See: https://eslint.org/docs/latest/rules/no-async-promise-executor
*/
const enqueueRequest = this.store.dispatch(
queueApi.endpoints.enqueueBatch.initiate(batch, {
// Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
// updates.
fixedCacheKey: 'enqueueBatch',
// We do not need RTK to track this request in the store
track: false,
})
);
// The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect.
// TODO(psyche): Fix the OpenAPI schema.
const { batch_id } = (await enqueueRequest.unwrap()).batch;
assert(batch_id, 'Enqueue result is missing batch_id');
const resultPromise = new Promise<ImageDTO>((resolve, reject) => {
const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => {
// Ignore events that are not for this graph
if (event.origin !== origin) {
return;
}
// Ignore events that are not from the output node
if (event.invocation_source_id !== outputNodeId) {
return;
}
// If we get here, the event is for the correct graph and output node.
// Clear the timeout and socket listeners
_clearTimeout();
clearListeners();
// The result must be an image output
const { result } = event;
if (result.type !== 'image_output') {
reject(new Error(`Graph output node did not return an image output, got: ${result}`));
return;
}
// Get the result image DTO
const getImageDTOResult = await withResultAsync(() => getImageDTO(result.image.image_name));
if (getImageDTOResult.isErr()) {
reject(getImageDTOResult.error);
return;
}
// Ok!
resolve(getImageDTOResult.value);
};
const queueItemStatusChangedHandler = (event: S['QueueItemStatusChangedEvent']) => {
// Ignore events that are not for this graph
if (event.origin !== origin) {
return;
}
// Ignore events where the status is pending or in progress - no need to do anything for these
if (event.status === 'pending' || event.status === 'in_progress') {
return;
}
// event.status is 'failed', 'canceled' or 'completed' - something has gone awry
_clearTimeout();
clearListeners();
if (event.status === 'completed') {
// If we get a queue item completed event, that means we never got a completion event for the output node!
reject(new Error('Queue item completed without output node completion event'));
} else if (event.status === 'failed') {
// We expect the event to have error details, but technically it's possible that it doesn't
const { error_type, error_message, error_traceback } = event;
if (error_type && error_message && error_traceback) {
reject(new QueueError(error_type, error_message, error_traceback));
} else {
reject(new Error('Queue item failed, but no error details were provided'));
}
} else {
// event.status is 'canceled'
reject(new Error('Graph canceled'));
}
};
this.manager.socket.on('invocation_complete', invocationCompleteHandler);
this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler);
const clearListeners = () => {
this.manager.socket.off('invocation_complete', invocationCompleteHandler);
this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler);
};
const cancelGraph = () => {
this.store.dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false }));
};
if (timeout) {
timeoutId = window.setTimeout(() => {
this.log.trace('Graph canceled by timeout');
clearListeners();
cancelGraph();
reject(new Error('Graph timed out'));
}, timeout);
}
if (signal) {
signal.addEventListener('abort', () => {
this.log.trace('Graph canceled by signal');
_clearTimeout();
clearListeners();
cancelGraph();
reject(new Error('Graph canceled'));
});
}
});
return resultPromise;
};
/**

View File

@ -1,7 +1,8 @@
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { AnyInvocation, ControlNetModelConfig, Invocation, T2IAdapterModelConfig } from 'services/api/types';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { z } from 'zod';
@ -132,7 +133,10 @@ export const isFilterType = (v: unknown): v is FilterType => zFilterType.safePar
type ImageFilterData<T extends FilterConfig['type']> = {
type: T;
buildDefaults(): Extract<FilterConfig, { type: T }>;
buildNode(imageDTO: ImageWithDims, config: Extract<FilterConfig, { type: T }>): AnyInvocation;
buildGraph(
imageDTO: ImageWithDims,
config: Extract<FilterConfig, { type: T }>
): { graph: Graph; outputNodeId: string };
validateConfig?(config: Extract<FilterConfig, { type: T }>): boolean;
};
@ -144,13 +148,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
low_threshold: 100,
high_threshold: 200,
}),
buildNode: ({ image_name }, { low_threshold, high_threshold }): Invocation<'canny_edge_detection'> => ({
id: getPrefixedId('canny_edge_detection'),
type: 'canny_edge_detection',
image: { image_name },
low_threshold,
high_threshold,
}),
buildGraph: ({ image_name }, { low_threshold, high_threshold }) => {
const graph = new Graph(getPrefixedId('canny_edge_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('canny_edge_detection'),
type: 'canny_edge_detection',
image: { image_name },
low_threshold,
high_threshold,
});
return {
graph,
outputNodeId: node.id,
};
},
},
color_map: {
type: 'color_map',
@ -158,12 +169,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
type: 'color_map',
tile_size: 64,
}),
buildNode: ({ image_name }, { tile_size }): Invocation<'color_map'> => ({
id: getPrefixedId('color_map'),
type: 'color_map',
image: { image_name },
tile_size,
}),
buildGraph: ({ image_name }, { tile_size }) => {
const graph = new Graph(getPrefixedId('color_map_filter'));
const node = graph.addNode({
id: getPrefixedId('color_map'),
type: 'color_map',
image: { image_name },
tile_size,
});
return {
graph,
outputNodeId: node.id,
};
},
},
content_shuffle: {
type: 'content_shuffle',
@ -171,12 +189,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
type: 'content_shuffle',
scale_factor: 256,
}),
buildNode: ({ image_name }, { scale_factor }): Invocation<'content_shuffle'> => ({
id: getPrefixedId('content_shuffle'),
type: 'content_shuffle',
image: { image_name },
scale_factor,
}),
buildGraph: ({ image_name }, { scale_factor }) => {
const graph = new Graph(getPrefixedId('content_shuffle_filter'));
const node = graph.addNode({
id: getPrefixedId('content_shuffle'),
type: 'content_shuffle',
image: { image_name },
scale_factor,
});
return {
graph,
outputNodeId: node.id,
};
},
},
depth_anything_depth_estimation: {
type: 'depth_anything_depth_estimation',
@ -184,12 +209,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
type: 'depth_anything_depth_estimation',
model_size: 'small_v2',
}),
buildNode: ({ image_name }, { model_size }): Invocation<'depth_anything_depth_estimation'> => ({
id: getPrefixedId('depth_anything_depth_estimation'),
type: 'depth_anything_depth_estimation',
image: { image_name },
model_size,
}),
buildGraph: ({ image_name }, { model_size }) => {
const graph = new Graph(getPrefixedId('depth_anything_depth_estimation_filter'));
const node = graph.addNode({
id: getPrefixedId('depth_anything_depth_estimation'),
type: 'depth_anything_depth_estimation',
image: { image_name },
model_size,
});
return {
graph,
outputNodeId: node.id,
};
},
},
hed_edge_detection: {
type: 'hed_edge_detection',
@ -197,23 +229,37 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
type: 'hed_edge_detection',
scribble: false,
}),
buildNode: ({ image_name }, { scribble }): Invocation<'hed_edge_detection'> => ({
id: getPrefixedId('hed_edge_detection'),
type: 'hed_edge_detection',
image: { image_name },
scribble,
}),
buildGraph: ({ image_name }, { scribble }) => {
const graph = new Graph(getPrefixedId('hed_edge_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('hed_edge_detection'),
type: 'hed_edge_detection',
image: { image_name },
scribble,
});
return {
graph,
outputNodeId: node.id,
};
},
},
lineart_anime_edge_detection: {
type: 'lineart_anime_edge_detection',
buildDefaults: () => ({
type: 'lineart_anime_edge_detection',
}),
buildNode: ({ image_name }): Invocation<'lineart_anime_edge_detection'> => ({
id: getPrefixedId('lineart_anime_edge_detection'),
type: 'lineart_anime_edge_detection',
image: { image_name },
}),
buildGraph: ({ image_name }) => {
const graph = new Graph(getPrefixedId('lineart_anime_edge_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('lineart_anime_edge_detection'),
type: 'lineart_anime_edge_detection',
image: { image_name },
});
return {
graph,
outputNodeId: node.id,
};
},
},
lineart_edge_detection: {
type: 'lineart_edge_detection',
@ -221,12 +267,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
type: 'lineart_edge_detection',
coarse: false,
}),
buildNode: ({ image_name }, { coarse }): Invocation<'lineart_edge_detection'> => ({
id: getPrefixedId('lineart_edge_detection'),
type: 'lineart_edge_detection',
image: { image_name },
coarse,
}),
buildGraph: ({ image_name }, { coarse }) => {
const graph = new Graph(getPrefixedId('lineart_edge_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('lineart_edge_detection'),
type: 'lineart_edge_detection',
image: { image_name },
coarse,
});
return {
graph,
outputNodeId: node.id,
};
},
},
mediapipe_face_detection: {
type: 'mediapipe_face_detection',
@ -235,13 +288,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
max_faces: 1,
min_confidence: 0.5,
}),
buildNode: ({ image_name }, { max_faces, min_confidence }): Invocation<'mediapipe_face_detection'> => ({
id: getPrefixedId('mediapipe_face_detection'),
type: 'mediapipe_face_detection',
image: { image_name },
max_faces,
min_confidence,
}),
buildGraph: ({ image_name }, { max_faces, min_confidence }) => {
const graph = new Graph(getPrefixedId('mediapipe_face_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('mediapipe_face_detection'),
type: 'mediapipe_face_detection',
image: { image_name },
max_faces,
min_confidence,
});
return {
graph,
outputNodeId: node.id,
};
},
},
mlsd_detection: {
type: 'mlsd_detection',
@ -250,24 +310,38 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
score_threshold: 0.1,
distance_threshold: 20.0,
}),
buildNode: ({ image_name }, { score_threshold, distance_threshold }): Invocation<'mlsd_detection'> => ({
id: getPrefixedId('mlsd_detection'),
type: 'mlsd_detection',
image: { image_name },
score_threshold,
distance_threshold,
}),
buildGraph: ({ image_name }, { score_threshold, distance_threshold }) => {
const graph = new Graph(getPrefixedId('mlsd_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('mlsd_detection'),
type: 'mlsd_detection',
image: { image_name },
score_threshold,
distance_threshold,
});
return {
graph,
outputNodeId: node.id,
};
},
},
normal_map: {
type: 'normal_map',
buildDefaults: () => ({
type: 'normal_map',
}),
buildNode: ({ image_name }): Invocation<'normal_map'> => ({
id: getPrefixedId('normal_map'),
type: 'normal_map',
image: { image_name },
}),
buildGraph: ({ image_name }) => {
const graph = new Graph(getPrefixedId('normal_map_filter'));
const node = graph.addNode({
id: getPrefixedId('normal_map'),
type: 'normal_map',
image: { image_name },
});
return {
graph,
outputNodeId: node.id,
};
},
},
pidi_edge_detection: {
type: 'pidi_edge_detection',
@ -276,13 +350,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
quantize_edges: false,
scribble: false,
}),
buildNode: ({ image_name }, { quantize_edges, scribble }): Invocation<'pidi_edge_detection'> => ({
id: getPrefixedId('pidi_edge_detection'),
type: 'pidi_edge_detection',
image: { image_name },
quantize_edges,
scribble,
}),
buildGraph: ({ image_name }, { quantize_edges, scribble }) => {
const graph = new Graph(getPrefixedId('pidi_edge_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('pidi_edge_detection'),
type: 'pidi_edge_detection',
image: { image_name },
quantize_edges,
scribble,
});
return {
graph,
outputNodeId: node.id,
};
},
},
dw_openpose_detection: {
type: 'dw_openpose_detection',
@ -292,14 +373,21 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
draw_face: true,
draw_hands: true,
}),
buildNode: ({ image_name }, { draw_body, draw_face, draw_hands }): Invocation<'dw_openpose_detection'> => ({
id: getPrefixedId('dw_openpose_detection'),
type: 'dw_openpose_detection',
image: { image_name },
draw_body,
draw_face,
draw_hands,
}),
buildGraph: ({ image_name }, { draw_body, draw_face, draw_hands }) => {
const graph = new Graph(getPrefixedId('dw_openpose_detection_filter'));
const node = graph.addNode({
id: getPrefixedId('dw_openpose_detection'),
type: 'dw_openpose_detection',
image: { image_name },
draw_body,
draw_face,
draw_hands,
});
return {
graph,
outputNodeId: node.id,
};
},
},
spandrel_filter: {
type: 'spandrel_filter',
@ -309,29 +397,30 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
autoScale: true,
scale: 1,
}),
buildNode: (
{ image_name },
{ model, scale, autoScale }
): Invocation<'spandrel_image_to_image' | 'spandrel_image_to_image_autoscale'> => {
buildGraph: ({ image_name }, { model, scale, autoScale }) => {
assert(model !== null);
if (autoScale) {
const node: Invocation<'spandrel_image_to_image_autoscale'> = {
id: getPrefixedId('spandrel_image_to_image_autoscale'),
type: 'spandrel_image_to_image_autoscale',
image_to_image_model: model,
image: { image_name },
scale,
};
return node;
} else {
const node: Invocation<'spandrel_image_to_image'> = {
id: getPrefixedId('spandrel_image_to_image'),
type: 'spandrel_image_to_image',
image_to_image_model: model,
image: { image_name },
};
return node;
}
const graph = new Graph(getPrefixedId('spandrel_filter'));
const node = graph.addNode(
autoScale
? {
id: getPrefixedId('spandrel_image_to_image_autoscale'),
type: 'spandrel_image_to_image_autoscale',
image_to_image_model: model,
image: { image_name },
scale,
}
: {
id: getPrefixedId('spandrel_image_to_image'),
type: 'spandrel_image_to_image',
image_to_image_model: model,
image: { image_name },
}
);
return {
graph,
outputNodeId: node.id,
};
},
validateConfig: (config): boolean => {
if (!config.model) {

View File

@ -7,7 +7,7 @@ import {
zParameterNegativePrompt,
zParameterPositivePrompt,
} from 'features/parameters/types/parameterSchemas';
import { getImageDTO } from 'services/api/endpoints/images';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { z } from 'zod';
@ -31,7 +31,7 @@ const zImageWithDims = z
})
.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTO(image_name, true);
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;

View File

@ -1,4 +1,4 @@
import { Flex, Link, Text } from '@invoke-ai/ui-library';
import { Flex, Link, Spacer, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { memo } from 'react';
@ -9,7 +9,7 @@ export const GalleryHeader = memo(() => {
if (projectName && projectUrl) {
return (
<Flex gap={2} w="full" alignItems="center" justifyContent="space-evenly" pe={2}>
<Flex gap={2} alignItems="center" justifyContent="space-evenly" pe={2} w="50%">
<Text fontSize="md" fontWeight="semibold" noOfLines={1} wordBreak="break-all" w="full" textAlign="center">
<Link href={projectUrl}>{projectName}</Link>
</Text>
@ -17,7 +17,7 @@ export const GalleryHeader = memo(() => {
);
}
return null;
return <Spacer />;
});
GalleryHeader.displayName = 'GalleryHeader';

View File

@ -1,4 +1,4 @@
import { Box, Button, Collapse, Divider, Flex, IconButton, Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Box, Button, Collapse, Divider, Flex, IconButton, useDisclosure } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useScopeOnFocus } from 'common/hooks/interactionScopes';
import { GalleryHeader } from 'features/gallery/components/GalleryHeader';
@ -52,18 +52,19 @@ const GalleryPanelContent = () => {
return (
<Flex ref={ref} position="relative" flexDirection="column" h="full" w="full" tabIndex={-1}>
<GalleryHeader />
<Flex alignItems="center" w="full">
<Button
size="sm"
variant="ghost"
onClick={boardsListPanel.toggle}
rightIcon={boardsListPanel.isCollapsed ? <PiCaretDownBold /> : <PiCaretUpBold />}
>
{boardsListPanel.isCollapsed ? t('boards.viewBoards') : t('boards.hideBoards')}
</Button>
<Spacer />
<Flex h="full">
<Flex w="25%">
<Button
size="sm"
variant="ghost"
onClick={boardsListPanel.toggle}
rightIcon={boardsListPanel.isCollapsed ? <PiCaretDownBold /> : <PiCaretUpBold />}
>
{boardsListPanel.isCollapsed ? t('boards.viewBoards') : t('boards.hideBoards')}
</Button>
</Flex>
<GalleryHeader />
<Flex h="full" w="25%" justifyContent="flex-end">
<GallerySettingsPopover />
<IconButton
size="sm"

View File

@ -3,7 +3,6 @@ import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { TypesafeDraggableData } from 'features/dnd/types';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
@ -12,15 +11,13 @@ import { selectShouldShowImageDetails, selectShouldShowProgressInViewer } from '
import type { AnimationProps } from 'framer-motion';
import { AnimatePresence, motion } from 'framer-motion';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { $hasProgress, $isProgressFromCanvas } from 'services/events/stores';
import { NoContentForViewer } from './NoContentForViewer';
import ProgressImage from './ProgressImage';
const CurrentImagePreview = () => {
const { t } = useTranslation();
const shouldShowImageDetails = useAppSelector(selectShouldShowImageDetails);
const imageName = useAppSelector(selectLastSelectedImageName);
const hasDenoiseProgress = useStore($hasProgress);
@ -72,7 +69,7 @@ const CurrentImagePreview = () => {
isUploadDisabled={true}
fitContainer
useThumbailFallback
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
noContentFallback={<NoContentForViewer />}
dataTestId="image-preview"
/>
)}

View File

@ -0,0 +1,59 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon';
import { LOADING_SYMBOL, useHasImages } from 'features/gallery/hooks/useHasImages';
import { Trans, useTranslation } from 'react-i18next';
import { PiImageBold } from 'react-icons/pi';
export const NoContentForViewer = () => {
const hasImages = useHasImages();
const { t } = useTranslation();
if (hasImages === LOADING_SYMBOL) {
return (
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
// If we show the logo while loading, there is an awkward layout shift where the invoke logo moves a bit. Less
// jarring to show a blank bg with a spinner - it will only be shown for a moment as we do the initial images
// fetching.
<Flex position="relative" width="full" height="full" alignItems="center" justifyContent="center">
<Spinner label="Loading" color="grey" position="absolute" size="sm" width={8} height={8} right={4} bottom={4} />
</Flex>
);
}
if (hasImages) {
return <IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />;
}
return (
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center" maxW="600px">
<InvokeLogoIcon w={40} h={40} />
<Text fontSize="md" color="base.200" pt={16}>
<Trans
i18nKey="newUserExperience.toGetStarted"
components={{
StrongComponent: <Text as="span" color="white" fontSize="md" fontWeight="semibold" />,
}}
/>
</Text>
<Text fontSize="md" color="base.200">
<Trans
i18nKey="newUserExperience.gettingStartedSeries"
components={{
LinkComponent: (
<Text
as="a"
color="white"
fontSize="md"
fontWeight="semibold"
href="https://www.youtube.com/@invokeai/videos"
target="_blank"
/>
),
}}
/>
</Text>
</Flex>
);
};

View File

@ -0,0 +1,33 @@
import { useMemo } from 'react';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { useListImagesQuery } from 'services/api/endpoints/images';
export const LOADING_SYMBOL = Symbol('LOADING');
export const useHasImages = () => {
const { data: boardList, isLoading: loadingBoards } = useListAllBoardsQuery({ include_archived: true });
const { data: uncategorizedImages, isLoading: loadingImages } = useListImagesQuery({
board_id: 'none',
offset: 0,
limit: 0,
is_intermediate: false,
});
const hasImages = useMemo(() => {
// default to true
if (loadingBoards || loadingImages) {
return LOADING_SYMBOL;
}
const hasBoards = boardList && boardList.length > 0;
if (hasBoards) {
if (boardList.filter((board) => board.image_count > 0).length > 0) {
return true;
}
}
return uncategorizedImages ? uncategorizedImages.total > 0 : true;
}, [boardList, uncategorizedImages, loadingBoards, loadingImages]);
return hasImages;
};

View File

@ -67,7 +67,7 @@ import {
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import { get, isArray, isString } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
@ -603,7 +603,7 @@ const parseIPAdapterToIPAdapterLayer: MetadataParseFunc<CanvasReferenceImageStat
begin_step_percent ?? initialIPAdapter.beginEndStepPct[0],
end_step_percent ?? initialIPAdapter.beginEndStepPct[1],
];
const imageDTO = image ? await getImageDTO(image.image_name) : null;
const imageDTO = image ? await getImageDTOSafe(image.image_name) : null;
const layer: CanvasReferenceImageState = {
id: getPrefixedId('ip_adapter'),

View File

@ -40,7 +40,7 @@ import { computed } from 'nanostores';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFlaskBold, PiHammerBold } from 'react-icons/pi';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold } from 'react-icons/pi';
import type { EdgeChange, NodeChange } from 'reactflow';
import type { S } from 'services/api/types';
@ -413,6 +413,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
<Flex alignItems="center" gap={2}>
{item.classification === 'beta' && <Icon boxSize={4} color="invokeYellow.300" as={PiHammerBold} />}
{item.classification === 'prototype' && <Icon boxSize={4} color="invokeRed.300" as={PiFlaskBold} />}
{item.classification === 'internal' && <Icon boxSize={4} color="invokePurple.300" as={PiCircuitryBold} />}
<Text fontWeight="semibold">{item.label}</Text>
<Spacer />
<Text variant="subtext" fontWeight="semibold">

View File

@ -3,7 +3,7 @@ import { useNodeClassification } from 'features/nodes/hooks/useNodeClassificatio
import type { Classification } from 'features/nodes/types/common';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFlaskBold, PiHammerBold } from 'react-icons/pi';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold } from 'react-icons/pi';
interface Props {
nodeId: string;
@ -22,7 +22,7 @@ const InvocationNodeClassificationIcon = ({ nodeId }: Props) => {
placement="top"
shouldWrapChildren
>
<Icon as={getIcon(classification)} display="block" boxSize={4} color="base.400" />
<ClassificationIcon classification={classification} />
</Tooltip>
);
};
@ -40,19 +40,27 @@ const ClassificationTooltipContent = memo(({ classification }: { classification:
return t('nodes.prototypeDesc');
}
if (classification === 'internal') {
return t('nodes.prototypeDesc');
}
return null;
});
ClassificationTooltipContent.displayName = 'ClassificationTooltipContent';
const getIcon = (classification: Classification) => {
const ClassificationIcon = ({ classification }: { classification: Classification }) => {
if (classification === 'beta') {
return PiHammerBold;
return <Icon as={PiHammerBold} display="block" boxSize={4} color="invokeYellow.300" />;
}
if (classification === 'prototype') {
return PiFlaskBold;
return <Icon as={PiFlaskBold} display="block" boxSize={4} color="invokeRed.300" />;
}
return undefined;
if (classification === 'internal') {
return <Icon as={PiCircuitryBold} display="block" boxSize={4} color="invokePurple.300" />;
}
return null;
};

View File

@ -22,7 +22,7 @@ export const zColorField = z.object({
});
export type ColorField = z.infer<typeof zColorField>;
export const zClassification = z.enum(['stable', 'beta', 'prototype', 'deprecated']);
export const zClassification = z.enum(['stable', 'beta', 'prototype', 'deprecated', 'internal']);
export type Classification = z.infer<typeof zClassification>;
export const zSchedulerField = z.enum([

View File

@ -2,7 +2,6 @@ import { deepClone } from 'common/util/deepClone';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { AnyInvocation, Invocation } from 'services/api/types';
import { assert, AssertionError, is } from 'tsafe';
import { validate } from 'uuid';
import { describe, expect, it } from 'vitest';
import { z } from 'zod';
@ -11,11 +10,12 @@ describe('Graph', () => {
it('should create a new graph with the correct id', () => {
const g = new Graph('test-id');
expect(g._graph.id).toBe('test-id');
expect(g.id).toBe('test-id');
});
it('should create a new graph with a uuid id if none is provided', () => {
it('should create an id if none is provided', () => {
const g = new Graph();
expect(g._graph.id).not.toBeUndefined();
expect(validate(g._graph.id)).toBeTruthy();
expect(g.id).not.toBeUndefined();
});
});

View File

@ -32,10 +32,12 @@ export type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edge
export class Graph {
_graph: GraphType;
_metadataNodeId = getPrefixedId('core_metadata');
id: string;
constructor(id?: string) {
this.id = id ?? Graph.getId('graph');
this._graph = {
id: id ?? uuidv4(),
id: this.id,
nodes: {},
edges: [],
};

View File

@ -1,3 +1,5 @@
import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
CanvasControlLayerState,
@ -6,9 +8,12 @@ import type {
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { serializeError } from 'serialize-error';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
type AddControlNetsResult = {
addedControlNets: number;
};
@ -33,9 +38,17 @@ export const addControlNets = async (
for (const layer of validControlLayers) {
result.addedControlNets++;
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
});
if (getImageDTOResult.isErr()) {
log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
continue;
}
const imageDTO = getImageDTOResult.value;
addControlNetToGraph(g, layer, imageDTO, collector);
}
@ -66,9 +79,17 @@ export const addT2IAdapters = async (
for (const layer of validControlLayers) {
result.addedT2IAdapters++;
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [], bg: 'black' } });
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
});
if (getImageDTOResult.isErr()) {
log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
continue;
}
const imageDTO = getImageDTOResult.value;
addT2IAdapterToGraph(g, layer, imageDTO, collector);
}

View File

@ -1,4 +1,6 @@
import { logger } from 'app/logging/logger';
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
@ -8,9 +10,12 @@ import type {
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { serializeError } from 'serialize-error';
import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
type AddedRegionResult = {
addedPositivePrompt: boolean;
addedNegativePrompt: boolean;
@ -64,9 +69,18 @@ export const addRegions = async (
addedAutoNegativePositivePrompt: false,
addedIPAdapters: 0,
};
const adapter = manager.adapters.regionMasks.get(region.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize({ rect: bbox });
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.regionMasks.get(region.id);
assert(adapter, 'Adapter not found');
return adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } });
});
if (getImageDTOResult.isErr()) {
log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing region mask');
continue;
}
const imageDTO = getImageDTOResult.value;
// The main mask-to-tensor node
const maskToTensor = g.addNode({

View File

@ -113,8 +113,8 @@ export const buildFLUXGraph = async (
g.upsertMetadata({
generation_mode: 'flux_txt2img',
guidance,
width: scaledSize.width,
height: scaledSize.height,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(modelConfig),
seed,

View File

@ -142,8 +142,8 @@ export const buildSD1Graph = async (
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
width: scaledSize.width,
height: scaledSize.height,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig),

View File

@ -141,8 +141,8 @@ export const buildSDXLGraph = async (
generation_mode: 'sdxl_txt2img',
cfg_scale,
cfg_rescale_multiplier,
width: scaledSize.width,
height: scaledSize.height,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig),

View File

@ -1,23 +1,41 @@
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
import { Box, Combobox, Flex, FormControl, FormLabel, Icon, Spacer, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectModelKey } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { MdMoneyOff } from 'react-icons/md';
import { useMainModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, MainModelConfig } from 'services/api/types';
import { type AnyModelConfig, isCheckpointMainModelConfig, type MainModelConfig } from 'services/api/types';
const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const activeTabName = useAppSelector(selectActiveTab);
const selectedModel = useAppSelector(selectModel);
const selectedModelKey = useAppSelector(selectModelKey);
// const selectedModel = useAppSelector(selectModel);
const [modelConfigs, { isLoading }] = useMainModels();
const selectedModel = useMemo(() => {
if (!modelConfigs) {
return null;
}
if (selectedModelKey === null) {
return null;
}
const modelConfig = modelConfigs.find((model) => model.key === selectedModelKey);
if (!modelConfig) {
return null;
}
return modelConfig;
}, [modelConfigs, selectedModelKey]);
const tooltipLabel = useMemo(() => {
if (!modelConfigs.length || !selectedModel) {
return;
@ -54,11 +72,26 @@ const ParamMainModelSelect = () => {
getIsDisabled,
});
const isFluxDevSelected = useMemo(() => {
return selectedModel && isCheckpointMainModelConfig(selectedModel) && selectedModel.config_path === 'flux-dev';
}, [selectedModel]);
return (
<FormControl isDisabled={!modelConfigs.length} isInvalid={!value || !modelConfigs.length}>
<InformationalPopover feature="paramModel">
<FormLabel>{t('modelManager.model')}</FormLabel>
</InformationalPopover>
<Flex>
<InformationalPopover feature="paramModel">
<FormLabel>{t('modelManager.model')}</FormLabel>
</InformationalPopover>
{isFluxDevSelected ? (
<InformationalPopover feature="fluxDevLicense" hideDisable={true}>
<Flex justifyContent="flex-start">
<Icon as={MdMoneyOff} />
</Flex>
</InformationalPopover>
) : (
<Spacer />
)}
</Flex>
<Tooltip label={tooltipLabel}>
<Box w="full" minW={0}>
<Combobox

View File

@ -4,7 +4,7 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsFLUX, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { selectCanvasSlice, selectScaleMethod } from 'features/controlLayers/store/selectors';
import { ParamOptimizedDenoisingToggle } from 'features/parameters/components/Advanced/ParamOptimizedDenoisingToggle';
import BboxScaledHeight from 'features/parameters/components/Bbox/BboxScaledHeight';
import BboxScaledWidth from 'features/parameters/components/Bbox/BboxScaledWidth';
@ -51,6 +51,7 @@ const scalingLabelProps: FormLabelProps = {
export const ImageSettingsAccordion = memo(() => {
const { t } = useTranslation();
const badges = useAppSelector(selectBadges);
const scaleMethod = useAppSelector(selectScaleMethod);
const { isOpen: isOpenAccordion, onToggle: onToggleAccordion } = useStandaloneAccordionToggle({
id: 'image-settings',
defaultIsOpen: true,
@ -80,10 +81,12 @@ export const ImageSettingsAccordion = memo(() => {
<Flex gap={4} pb={4} flexDir="column">
{isFLUX && <ParamOptimizedDenoisingToggle />}
<BboxScaleMethod />
<FormControlGroup formLabelProps={scalingLabelProps}>
<BboxScaledWidth />
<BboxScaledHeight />
</FormControlGroup>
{scaleMethod !== 'none' && (
<FormControlGroup formLabelProps={scalingLabelProps}>
<BboxScaledWidth />
<BboxScaledHeight />
</FormControlGroup>
)}
</Flex>
</Expander>
</Flex>

View File

@ -1,18 +1,18 @@
import {
Box,
Flex,
IconButton,
Image,
Popover,
PopoverArrow,
PopoverBody,
PopoverCloseButton,
PopoverContent,
PopoverHeader,
PopoverTrigger,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shouldShowNotificationIndicatorChanged } from 'features/ui/store/uiSlice';
import { shouldShowNotificationChanged } from 'features/ui/store/uiSlice';
import InvokeSymbol from 'public/assets/images/invoke-favicon.png';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -24,9 +24,9 @@ import { CanvasV2Announcement } from './CanvasV2Announcement';
export const Notifications = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const shouldShowNotificationIndicator = useAppSelector((s) => s.ui.shouldShowNotificationIndicator);
const shouldShowNotification = useAppSelector((s) => s.ui.shouldShowNotification);
const resetIndicator = useCallback(() => {
dispatch(shouldShowNotificationIndicatorChanged(false));
dispatch(shouldShowNotificationChanged(false));
}, [dispatch]);
const { data } = useGetAppVersionQuery();
@ -35,7 +35,7 @@ export const Notifications = () => {
}
return (
<Popover onOpen={resetIndicator} placement="top-start">
<Popover onClose={resetIndicator} placement="top-start" autoFocus={false} defaultIsOpen={shouldShowNotification}>
<PopoverTrigger>
<Flex pos="relative">
<IconButton
@ -44,22 +44,12 @@ export const Notifications = () => {
icon={<PiLightbulbFilamentBold fontSize={20} />}
boxSize={8}
/>
{shouldShowNotificationIndicator && (
<Box
pos="absolute"
top={0}
right="2px"
w={2}
h={2}
backgroundColor="invokeYellow.500"
borderRadius="100%"
/>
)}
</Flex>
</PopoverTrigger>
<PopoverContent p={2}>
<PopoverArrow />
<PopoverHeader fontSize="md" fontWeight="semibold">
<PopoverCloseButton />
<PopoverHeader fontSize="md" fontWeight="semibold" pt={5}>
<Flex alignItems="center" gap={3}>
<Image src={InvokeSymbol} boxSize={6} />
{t('whatsNew.whatsNewInInvoke')}

View File

@ -13,7 +13,7 @@ const initialUIState: UIState = {
shouldShowProgressInViewer: true,
accordions: {},
expanders: {},
shouldShowNotificationIndicator: true,
shouldShowNotification: true,
};
export const uiSlice = createSlice({
@ -37,8 +37,8 @@ export const uiSlice = createSlice({
const { id, isOpen } = action.payload;
state.expanders[id] = isOpen;
},
shouldShowNotificationIndicatorChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowNotificationIndicator = action.payload;
shouldShowNotificationChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowNotification = action.payload;
},
},
extraReducers(builder) {
@ -54,7 +54,7 @@ export const {
setShouldShowProgressInViewer,
accordionStateChanged,
expanderStateChanged,
shouldShowNotificationIndicatorChanged,
shouldShowNotificationChanged,
} = uiSlice.actions;
export const selectUiSlice = (state: RootState) => state.ui;

View File

@ -26,7 +26,7 @@ export interface UIState {
*/
expanders: Record<string, boolean>;
/**
* Whether or not to show the user an indicator on notifications icon.
* Whether or not to show the user the open notification.
*/
shouldShowNotificationIndicator: boolean;
shouldShowNotification: boolean;
}

View File

@ -1,3 +1,4 @@
import type { StartQueryActionCreatorOptions } from '@reduxjs/toolkit/dist/query/core/buildInitiate';
import { getStore } from 'app/store/nanostores/store';
import type { SerializableObject } from 'common/types';
import type { BoardId } from 'features/gallery/store/types';
@ -568,25 +569,40 @@ export const {
/**
* Imperative RTKQ helper to fetch an ImageDTO.
* @param image_name The name of the image to fetch
* @param forceRefetch Whether to force a refetch of the image
* @returns
* @param options The options for the query. By default, the query will not subscribe to the store.
* @returns The ImageDTO if found, otherwise null
*/
export const getImageDTO = async (image_name: string, forceRefetch?: boolean): Promise<ImageDTO | null> => {
const options = {
export const getImageDTOSafe = async (
image_name: string,
options?: StartQueryActionCreatorOptions
): Promise<ImageDTO | null> => {
const _options = {
subscribe: false,
forceRefetch,
...options,
};
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, options));
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, _options));
try {
const imageDTO = await req.unwrap();
req.unsubscribe();
return imageDTO;
return await req.unwrap();
} catch {
req.unsubscribe();
return null;
}
};
/**
* Imperative RTKQ helper to fetch an ImageDTO.
* @param image_name The name of the image to fetch
* @param options The options for the query. By default, the query will not subscribe to the store.
* @raises Error if the image is not found or there is an error fetching the image
*/
export const getImageDTO = (image_name: string, options?: StartQueryActionCreatorOptions): Promise<ImageDTO> => {
const _options = {
subscribe: false,
...options,
};
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, _options));
return req.unwrap();
};
export type UploadOptions = {
blob: Blob;
fileName: string;
@ -596,7 +612,7 @@ export type UploadOptions = {
board_id?: BoardId;
metadata?: SerializableObject;
};
export const uploadImage = async (arg: UploadOptions): Promise<ImageDTO> => {
export const uploadImage = (arg: UploadOptions): Promise<ImageDTO> => {
const { blob, fileName, image_category, is_intermediate, crop_visible = false, board_id, metadata } = arg;
const { dispatch } = getStore();
@ -612,5 +628,5 @@ export const uploadImage = async (arg: UploadOptions): Promise<ImageDTO> => {
})
);
req.reset();
return await req.unwrap();
return req.unwrap();
};

View File

@ -3330,38 +3330,6 @@ export type components = {
*/
type: "canvas_v2_mask_and_crop";
};
/** CanvasV2MaskAndCropOutput */
CanvasV2MaskAndCropOutput: {
/** @description The output image */
image: components["schemas"]["ImageField"];
/**
* Width
* @description The width of the image in pixels
*/
width: number;
/**
* Height
* @description The height of the image in pixels
*/
height: number;
/**
* type
* @default canvas_v2_mask_and_crop_output
* @constant
* @enum {string}
*/
type: "canvas_v2_mask_and_crop_output";
/**
* Offset X
* @description The x offset of the image, after cropping
*/
offset_x: number;
/**
* Offset Y
* @description The y offset of the image, after cropping
*/
offset_y: number;
};
/**
* Center Pad or Crop Image
* @description Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image.
@ -3428,9 +3396,10 @@ export type components = {
* - `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
* - `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
* - `Deprecated`: The invocation is deprecated and may be removed in a future version.
* - `Internal`: The invocation is not intended for use by end-users. It may be changed or removed at any time, but is exposed for users to play with.
* @enum {string}
*/
Classification: "stable" | "beta" | "prototype" | "deprecated";
Classification: "stable" | "beta" | "prototype" | "deprecated" | "internal";
/**
* ClearResult
* @description Result of clearing the session queue
@ -6896,7 +6865,7 @@ export type components = {
* @description The results of node executions
*/
results?: {
[key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CanvasV2MaskAndCropOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
[key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
};
/**
* Errors
@ -9354,7 +9323,7 @@ export type components = {
* Result
* @description The result of the invocation
*/
result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CanvasV2MaskAndCropOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
};
/**
* InvocationDenoiseProgressEvent
@ -9524,7 +9493,7 @@ export type components = {
canny_edge_detection: components["schemas"]["ImageOutput"];
canny_image_processor: components["schemas"]["ImageOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
canvas_v2_mask_and_crop: components["schemas"]["CanvasV2MaskAndCropOutput"];
canvas_v2_mask_and_crop: components["schemas"]["ImageOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
color: components["schemas"]["ColorOutput"];

View File

@ -129,6 +129,10 @@ export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is M
return config.type === 'main' && config.base !== 'sdxl-refiner';
};
export const isCheckpointMainModelConfig = (config: AnyModelConfig): config is CheckpointModelConfig => {
return config.type === 'main' && (config.format === 'checkpoint' || config.format === 'bnb_quantized_nf4b');
};
export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sdxl-refiner';
};

View File

@ -0,0 +1,23 @@
/**
* A custom error class for queue event errors. These errors have a type, message and traceback.
*/
export class QueueError extends Error {
type: string;
traceback: string;
constructor(type: string, message: string, traceback: string) {
super(message);
this.name = 'QueueError';
this.type = type;
this.traceback = traceback;
if (Error.captureStackTrace) {
Error.captureStackTrace(this, QueueError);
}
}
toString() {
return `${this.name} [${this.type}]: ${this.message}\nTraceback:\n${this.traceback}`;
}
}

View File

@ -7,7 +7,7 @@ import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } fro
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { $lastProgressEvent } from 'services/events/stores';
@ -88,9 +88,7 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
const getResultImageDTO = (data: S['InvocationCompleteEvent']) => {
const { result } = data;
if (result.type === 'image_output') {
return getImageDTO(result.image.image_name);
} else if (result.type === 'canvas_v2_mask_and_crop_output') {
return getImageDTO(result.image.image_name);
return getImageDTOSafe(result.image.image_name);
}
return null;
};
@ -125,10 +123,7 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
if (data.destination === 'canvas') {
// TODO(psyche): Can/should we let canvas handle this itself?
if (isCanvasOutputNode(data)) {
if (data.result.type === 'canvas_v2_mask_and_crop_output') {
const { offset_x, offset_y } = data.result;
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } }));
} else if (data.result.type === 'image_output') {
if (data.result.type === 'image_output') {
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
}
addImageToGallery(data, imageDTO);

View File

@ -192,4 +192,6 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_sidecar_patches = model(input)
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches)
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
# differences are tolerable and expected due to the difference between sidecar vs. patching.
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)