From 4a83700fe4003ce44f4f6d8b94ed59018360a4c9 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 27 Mar 2025 15:08:57 -0400 Subject: [PATCH 01/21] if clientSideUploading is enabled, handle bulk uploads using that flow --- invokeai/app/api/routers/images.py | 16 ++ .../listeners/imageUploaded.ts | 36 +++- .../frontend/web/src/app/types/invokeai.ts | 2 +- .../src/common/hooks/useClientSideUpload.ts | 105 +++++++++ .../src/common/hooks/useImageUploadButton.tsx | 42 ++-- .../src/features/dnd/FullscreenDropzone.tsx | 47 ++-- .../components/GalleryUploadButton.tsx | 19 +- .../web/src/features/gallery/store/actions.ts | 7 + .../src/features/system/store/configSlice.ts | 4 +- .../web/src/services/api/endpoints/images.ts | 36 ++++ .../frontend/web/src/services/api/schema.ts | 201 ++++++++++++------ .../frontend/web/src/services/api/types.ts | 3 + 12 files changed, 376 insertions(+), 142 deletions(-) create mode 100644 invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index c86b554f9a..3760d672a9 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -96,6 +96,22 @@ async def upload_image( raise HTTPException(status_code=500, detail="Failed to create image") +class ImageUploadEntry(BaseModel): + image_dto: ImageDTO = Body(description="The image DTO") + presigned_url: str = Body(description="The URL to get the presigned URL for the image upload") + + +@images_router.post("/", operation_id="create_image_upload_entry") +async def create_image_upload_entry( + width: int = Body(description="The width of the image"), + height: int = Body(description="The height of the image"), + board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"), +) -> ImageUploadEntry: + """Uploads an image from a URL, not implemented""" + + raise HTTPException(status_code=501, detail="Not implemented") + + @images_router.delete("/i/{image_name}", operation_id="delete_image") async def delete_image( image_name: str = Path(description="The name of the image to delete"), diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index daeab9e3a5..da14203fd4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -1,6 +1,8 @@ +import { isAnyOf } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { RootState } from 'app/store/store'; +import { imageUploadedClientSide } from 'features/gallery/store/actions'; import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors'; import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice'; import { toast } from 'features/toast/toast'; @@ -8,7 +10,7 @@ import { t } from 'i18next'; import { omit } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; - +import type { ImageDTO } from 'services/api/types'; const log = logger('gallery'); /** @@ -34,19 +36,33 @@ let lastUploadedToastTimeout: number | null = null; export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => { startAppListening({ - matcher: imagesApi.endpoints.uploadImage.matchFulfilled, + matcher: isAnyOf(imagesApi.endpoints.uploadImage.matchFulfilled, imageUploadedClientSide), effect: (action, { dispatch, getState }) => { - const imageDTO = action.payload; + let imageDTO: ImageDTO; + let silent; + let isFirstUploadOfBatch = true; + + if (imageUploadedClientSide.match(action)) { + imageDTO = action.payload.imageDTO; + silent = action.payload.silent; + isFirstUploadOfBatch = action.payload.isFirstUploadOfBatch; + } else if (imagesApi.endpoints.uploadImage.matchFulfilled(action)) { + imageDTO = action.payload; + silent = action.meta.arg.originalArgs.silent; + isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true; + } else { + return; + } + + if (silent || imageDTO.is_intermediate) { + // If the image is silent or intermediate, we don't want to show a toast + return; + } + const state = getState(); log.debug({ imageDTO }, 'Image uploaded'); - if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) { - // When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions, - // like toasts and switching the gallery view - return; - } - const boardId = imageDTO.board_id ?? 'none'; const DEFAULT_UPLOADED_TOAST = { @@ -80,7 +96,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis * * Default to true to not require _all_ image upload handlers to set this value */ - const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true; + if (isFirstUploadOfBatch) { dispatch(boardIdSelected({ boardId })); dispatch(galleryViewChanged('assets')); diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 1a540d6743..a837894916 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -73,6 +73,7 @@ export type AppConfig = { maxUpscaleDimension?: number; allowPrivateBoards: boolean; allowPrivateStylePresets: boolean; + allowClientSideUpload: boolean; disabledTabs: TabName[]; disabledFeatures: AppFeature[]; disabledSDFeatures: SDFeature[]; @@ -81,7 +82,6 @@ export type AppConfig = { metadataFetchDebounce?: number; workflowFetchDebounce?: number; isLocal?: boolean; - maxImageUploadCount?: number; sd: { defaultModel?: string; disabledControlNetModels: string[]; diff --git a/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts new file mode 100644 index 0000000000..bc78794d77 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useClientSideUpload.ts @@ -0,0 +1,105 @@ +import { useStore } from '@nanostores/react'; +import { $authToken } from 'app/store/nanostores/authToken'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { imageUploadedClientSide } from 'features/gallery/store/actions'; +import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; +import { useCallback } from 'react'; +import { useCreateImageUploadEntryMutation } from 'services/api/endpoints/images'; +import type { ImageDTO } from 'services/api/types'; +export const useClientSideUpload = () => { + const dispatch = useAppDispatch(); + const autoAddBoardId = useAppSelector(selectAutoAddBoardId); + const authToken = useStore($authToken); + const [createImageUploadEntry] = useCreateImageUploadEntryMutation(); + + const clientSideUpload = useCallback( + async (file: File, i: number): Promise => { + const image = new Image(); + const objectURL = URL.createObjectURL(file); + image.src = objectURL; + let width = 0; + let height = 0; + let thumbnail: Blob | undefined; + + await new Promise((resolve) => { + image.onload = () => { + width = image.naturalWidth; + height = image.naturalHeight; + + // Calculate thumbnail dimensions maintaining aspect ratio + let thumbWidth = width; + let thumbHeight = height; + if (width > height && width > 256) { + thumbWidth = 256; + thumbHeight = Math.round((height * 256) / width); + } else if (height > 256) { + thumbHeight = 256; + thumbWidth = Math.round((width * 256) / height); + } + + const canvas = document.createElement('canvas'); + canvas.width = thumbWidth; + canvas.height = thumbHeight; + const ctx = canvas.getContext('2d'); + ctx?.drawImage(image, 0, 0, thumbWidth, thumbHeight); + + canvas.toBlob( + (blob) => { + if (blob) { + thumbnail = blob; + // Clean up resources + URL.revokeObjectURL(objectURL); + image.src = ''; // Clear image source + image.remove(); // Remove the image element + canvas.width = 0; // Clear canvas + canvas.height = 0; + resolve(); + } + }, + 'image/webp', + 0.8 + ); + }; + + // Handle load errors + image.onerror = () => { + URL.revokeObjectURL(objectURL); + image.remove(); + resolve(); + }; + }); + const { presigned_url, image_dto } = await createImageUploadEntry({ + width, + height, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + }).unwrap(); + + await fetch(`${presigned_url}/?type=full`, { + method: 'PUT', + body: file, + ...(authToken && { + headers: { + Authorization: `Bearer ${authToken}`, + }, + }), + }); + + await fetch(`${presigned_url}/?type=thumbnail`, { + method: 'PUT', + body: thumbnail, + ...(authToken && { + headers: { + Authorization: `Bearer ${authToken}`, + }, + }), + }); + + dispatch(imageUploadedClientSide({ imageDTO: image_dto, silent: false, isFirstUploadOfBatch: i === 0 })); + + return image_dto; + }, + [autoAddBoardId, authToken, createImageUploadEntry, dispatch] + ); + + return clientSideUpload; +}; diff --git a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx index d0445fb61a..db2700f53a 100644 --- a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx +++ b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx @@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import { useAppSelector } from 'app/store/storeHooks'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; -import { selectMaxImageUploadCount } from 'features/system/store/configSlice'; +import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import type { FileRejection } from 'react-dropzone'; @@ -15,6 +15,7 @@ import type { ImageDTO } from 'services/api/types'; import { assert } from 'tsafe'; import type { SetOptional } from 'type-fest'; +import { useClientSideUpload } from './useClientSideUpload'; type UseImageUploadButtonArgs = | { isDisabled?: boolean; @@ -50,8 +51,9 @@ const log = logger('gallery'); */ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => { const autoAddBoardId = useAppSelector(selectAutoAddBoardId); + const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled); const [uploadImage, request] = useUploadImageMutation(); - const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount); + const clientSideUpload = useClientSideUpload(); const { t } = useTranslation(); const onDropAccepted = useCallback( @@ -79,22 +81,27 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us onUpload(imageDTO); } } else { - const imageDTOs = await uploadImages( - files.map((file, i) => ({ - file, - image_category: 'user', - is_intermediate: false, - board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - silent: false, - isFirstUploadOfBatch: i === 0, - })) - ); + let imageDTOs: ImageDTO[] = []; + if (isClientSideUploadEnabled) { + imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i))); + } else { + imageDTOs = await uploadImages( + files.map((file, i) => ({ + file, + image_category: 'user', + is_intermediate: false, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + silent: false, + isFirstUploadOfBatch: i === 0, + })) + ); + } if (onUpload) { onUpload(imageDTOs); } } }, - [allowMultiple, autoAddBoardId, onUpload, uploadImage] + [allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload] ); const onDropRejected = useCallback( @@ -105,10 +112,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us file: rejection.file.path, })); log.error({ errors }, 'Invalid upload'); - const description = - maxImageUploadCount === undefined - ? t('toast.uploadFailedInvalidUploadDesc') - : t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount }); + const description = t('toast.uploadFailedInvalidUploadDesc'); toast({ id: 'UPLOAD_FAILED', @@ -120,7 +124,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us return; } }, - [maxImageUploadCount, t] + [t] ); const { @@ -137,8 +141,6 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us onDropRejected, disabled: isDisabled, noDrag: true, - multiple: allowMultiple && (maxImageUploadCount === undefined || maxImageUploadCount > 1), - maxFiles: maxImageUploadCount, }); return { getUploadButtonProps, getUploadInputProps, openUploader, request }; diff --git a/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx b/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx index baa4c2d210..8da1968859 100644 --- a/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx +++ b/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx @@ -8,12 +8,13 @@ import { useStore } from '@nanostores/react'; import { getStore } from 'app/store/nanostores/store'; import { useAppSelector } from 'app/store/storeHooks'; import { $focusedRegion } from 'common/hooks/focus'; +import { useClientSideUpload } from 'common/hooks/useClientSideUpload'; import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal'; import { DndDropOverlay } from 'features/dnd/DndDropOverlay'; import type { DndTargetState } from 'features/dnd/types'; import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; -import { selectMaxImageUploadCount } from 'features/system/store/configSlice'; +import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice'; import { toast } from 'features/toast/toast'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { memo, useCallback, useEffect, useRef, useState } from 'react'; @@ -53,13 +54,6 @@ const zUploadFile = z (file) => ({ message: `File extension .${file.name.split('.').at(-1)} is not supported` }) ); -const getFilesSchema = (max?: number) => { - if (max === undefined) { - return z.array(zUploadFile); - } - return z.array(zUploadFile).max(max); -}; - const sx = { position: 'absolute', top: 2, @@ -74,22 +68,19 @@ const sx = { export const FullscreenDropzone = memo(() => { const { t } = useTranslation(); const ref = useRef(null); - const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount); const [dndState, setDndState] = useState('idle'); const activeTab = useAppSelector(selectActiveTab); const isImageViewerOpen = useStore($imageViewer); + const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled); + const clientSideUpload = useClientSideUpload(); const validateAndUploadFiles = useCallback( - (files: File[]) => { + async (files: File[]) => { const { getState } = getStore(); - const uploadFilesSchema = getFilesSchema(maxImageUploadCount); - const parseResult = uploadFilesSchema.safeParse(files); + const parseResult = z.array(zUploadFile).safeParse(files); if (!parseResult.success) { - const description = - maxImageUploadCount === undefined - ? t('toast.uploadFailedInvalidUploadDesc') - : t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount }); + const description = t('toast.uploadFailedInvalidUploadDesc'); toast({ id: 'UPLOAD_FAILED', @@ -118,17 +109,23 @@ export const FullscreenDropzone = memo(() => { const autoAddBoardId = selectAutoAddBoardId(getState()); - const uploadArgs: UploadImageArg[] = files.map((file, i) => ({ - file, - image_category: 'user', - is_intermediate: false, - board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - isFirstUploadOfBatch: i === 0, - })); + if (isClientSideUploadEnabled) { + for (const [i, file] of files.entries()) { + await clientSideUpload(file, i); + } + } else { + const uploadArgs: UploadImageArg[] = files.map((file, i) => ({ + file, + image_category: 'user', + is_intermediate: false, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + isFirstUploadOfBatch: i === 0, + })); - uploadImages(uploadArgs); + uploadImages(uploadArgs); + } }, - [activeTab, isImageViewerOpen, maxImageUploadCount, t] + [activeTab, isImageViewerOpen, t, isClientSideUploadEnabled, clientSideUpload] ); const onPaste = useCallback( diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryUploadButton.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryUploadButton.tsx index ba9c87a90e..ca6cc78051 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryUploadButton.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryUploadButton.tsx @@ -1,31 +1,18 @@ import { IconButton } from '@invoke-ai/ui-library'; -import { useAppSelector } from 'app/store/storeHooks'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; -import { selectMaxImageUploadCount } from 'features/system/store/configSlice'; import { t } from 'i18next'; -import { useMemo } from 'react'; import { PiUploadBold } from 'react-icons/pi'; export const GalleryUploadButton = () => { - const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount); - const uploadOptions = useMemo(() => ({ allowMultiple: maxImageUploadCount !== 1 }), [maxImageUploadCount]); - const uploadApi = useImageUploadButton(uploadOptions); + const uploadApi = useImageUploadButton({ allowMultiple: true }); return ( <> 1 - ? t('accessibility.uploadImages') - : t('accessibility.uploadImage') - } - tooltip={ - maxImageUploadCount === undefined || maxImageUploadCount > 1 - ? t('accessibility.uploadImages') - : t('accessibility.uploadImage') - } + aria-label={t('accessibility.uploadImages')} + tooltip={t('accessibility.uploadImages')} icon={} {...uploadApi.getUploadButtonProps()} /> diff --git a/invokeai/frontend/web/src/features/gallery/store/actions.ts b/invokeai/frontend/web/src/features/gallery/store/actions.ts index 75c5d5846f..8d13c44936 100644 --- a/invokeai/frontend/web/src/features/gallery/store/actions.ts +++ b/invokeai/frontend/web/src/features/gallery/store/actions.ts @@ -1,4 +1,5 @@ import { createAction } from '@reduxjs/toolkit'; +import type { ImageDTO } from 'services/api/types'; export const sentImageToCanvas = createAction('gallery/sentImageToCanvas'); @@ -7,3 +8,9 @@ export const imageDownloaded = createAction('gallery/imageDownloaded'); export const imageCopiedToClipboard = createAction('gallery/imageCopiedToClipboard'); export const imageOpenedInNewTab = createAction('gallery/imageOpenedInNewTab'); + +export const imageUploadedClientSide = createAction<{ + imageDTO: ImageDTO; + silent: boolean; + isFirstUploadOfBatch: boolean; +}>('gallery/imageUploadedClientSide'); diff --git a/invokeai/frontend/web/src/features/system/store/configSlice.ts b/invokeai/frontend/web/src/features/system/store/configSlice.ts index 2d524d244e..e48226ea1c 100644 --- a/invokeai/frontend/web/src/features/system/store/configSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/configSlice.ts @@ -20,6 +20,7 @@ const initialConfigState: AppConfig = { shouldFetchMetadataFromApi: false, allowPrivateBoards: false, allowPrivateStylePresets: false, + allowClientSideUpload: false, disabledTabs: [], disabledFeatures: ['lightbox', 'faceRestore', 'batches'], disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'], @@ -218,6 +219,5 @@ export const selectWorkflowFetchDebounce = createConfigSelector((config) => conf export const selectMetadataFetchDebounce = createConfigSelector((config) => config.metadataFetchDebounce ?? 300); export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models')); -export const selectMaxImageUploadCount = createConfigSelector((config) => config.maxImageUploadCount); - +export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload); export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal); diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index a8ad744b3b..adefdb62da 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -7,6 +7,8 @@ import type { DeleteBoardResult, GraphAndWorkflowResponse, ImageDTO, + ImageUploadEntryRequest, + ImageUploadEntryResponse, ListImagesArgs, ListImagesResponse, UploadImageArg, @@ -287,6 +289,7 @@ export const imagesApi = api.injectEndpoints({ }, }; }, + invalidatesTags: (result) => { if (!result || result.is_intermediate) { // Don't add it to anything @@ -314,7 +317,39 @@ export const imagesApi = api.injectEndpoints({ ]; }, }), + createImageUploadEntry: build.mutation({ + query: ({ width, height }) => ({ + url: buildImagesUrl(), + method: 'POST', + body: { width, height }, + }), + invalidatesTags: (result) => { + if (!result) { + // Don't add it to anything + return []; + } + const categories = getCategories(result.image_dto); + const boardId = result.image_dto.board_id ?? 'none'; + return [ + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: boardId, + categories, + }), + }, + { + type: 'Board', + id: boardId, + }, + { + type: 'BoardImagesTotal', + id: boardId, + }, + ]; + }, + }), deleteBoard: build.mutation({ query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }), invalidatesTags: () => [ @@ -549,6 +584,7 @@ export const { useGetImageWorkflowQuery, useLazyGetImageWorkflowQuery, useUploadImageMutation, + useCreateImageUploadEntryMutation, useClearIntermediatesMutation, useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 21da55167e..2472eef495 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -466,6 +466,30 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/images/": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * List Image Dtos + * @description Gets a list of image DTOs + */ + get: operations["list_image_dtos"]; + put?: never; + /** + * Create Image Upload Entry + * @description Uploads an image from a URL, not implemented + */ + post: operations["create_image_upload_entry"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/images/i/{image_name}": { parameters: { query?: never; @@ -619,26 +643,6 @@ export type paths = { patch?: never; trace?: never; }; - "/api/v1/images/": { - parameters: { - query?: never; - header?: never; - path?: never; - cookie?: never; - }; - /** - * List Image Dtos - * @description Gets a list of image DTOs - */ - get: operations["list_image_dtos"]; - put?: never; - post?: never; - delete?: never; - options?: never; - head?: never; - patch?: never; - trace?: never; - }; "/api/v1/images/delete": { parameters: { query?: never; @@ -2358,6 +2362,24 @@ export type components = { */ batch_ids: string[]; }; + /** Body_create_image_upload_entry */ + Body_create_image_upload_entry: { + /** + * Width + * @description The width of the image + */ + width: number; + /** + * Height + * @description The height of the image + */ + height: number; + /** + * Board Id + * @description The board to add this image to, if any + */ + board_id?: string | null; + }; /** Body_create_style_preset */ Body_create_style_preset: { /** @@ -10754,6 +10776,16 @@ export type components = { */ type: "i2l"; }; + /** ImageUploadEntry */ + ImageUploadEntry: { + /** @description The image DTO */ + image_dto: components["schemas"]["ImageDTO"]; + /** + * Presigned Url + * @description The URL to get the presigned URL for the image upload + */ + presigned_url: string; + }; /** * ImageUrlsDTO * @description The URLs for an image and its thumbnail. @@ -23219,6 +23251,87 @@ export interface operations { }; }; }; + list_image_dtos: { + parameters: { + query?: { + /** @description The origin of images to list. */ + image_origin?: components["schemas"]["ResourceOrigin"] | null; + /** @description The categories of image to include. */ + categories?: components["schemas"]["ImageCategory"][] | null; + /** @description Whether to list intermediate images. */ + is_intermediate?: boolean | null; + /** @description The board id to filter by. Use 'none' to find images without a board. */ + board_id?: string | null; + /** @description The page offset */ + offset?: number; + /** @description The number of images per page */ + limit?: number; + /** @description The order of sort */ + order_dir?: components["schemas"]["SQLiteDirection"]; + /** @description Whether to sort by starred images first */ + starred_first?: boolean; + /** @description The term to search for */ + search_term?: string | null; + }; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + create_image_upload_entry: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["Body_create_image_upload_entry"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["ImageUploadEntry"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; get_image_dto: { parameters: { query?: never; @@ -23572,54 +23685,6 @@ export interface operations { }; }; }; - list_image_dtos: { - parameters: { - query?: { - /** @description The origin of images to list. */ - image_origin?: components["schemas"]["ResourceOrigin"] | null; - /** @description The categories of image to include. */ - categories?: components["schemas"]["ImageCategory"][] | null; - /** @description Whether to list intermediate images. */ - is_intermediate?: boolean | null; - /** @description The board id to filter by. Use 'none' to find images without a board. */ - board_id?: string | null; - /** @description The page offset */ - offset?: number; - /** @description The number of images per page */ - limit?: number; - /** @description The order of sort */ - order_dir?: components["schemas"]["SQLiteDirection"]; - /** @description Whether to sort by starred images first */ - starred_first?: boolean; - /** @description The term to search for */ - search_term?: string | null; - }; - header?: never; - path?: never; - cookie?: never; - }; - requestBody?: never; - responses: { - /** @description Successful Response */ - 200: { - headers: { - [name: string]: unknown; - }; - content: { - "application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"]; - }; - }; - /** @description Validation Error */ - 422: { - headers: { - [name: string]: unknown; - }; - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; delete_images_from_list: { parameters: { query?: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 9692bf2f77..cf498377e3 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -354,3 +354,6 @@ export type UploadImageArg = { */ isFirstUploadOfBatch?: boolean; }; + +export type ImageUploadEntryResponse = S['ImageUploadEntry']; +export type ImageUploadEntryRequest = paths['/api/v1/images/']['post']['requestBody']['content']['application/json']; From e50c7e5947468e6f4f132df71088f68a88573a0d Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 27 Mar 2025 15:14:02 -0400 Subject: [PATCH 02/21] restore multiple key --- invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx index db2700f53a..9fedc501f5 100644 --- a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx +++ b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx @@ -141,6 +141,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us onDropRejected, disabled: isDisabled, noDrag: true, + multiple: allowMultiple, }); return { getUploadButtonProps, getUploadInputProps, openUploader, request }; From 3f58c68c096e1ab074918e436617535cb09369a6 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 27 Mar 2025 15:40:09 -0400 Subject: [PATCH 03/21] fix tag invalidation --- .../listeners/imageUploaded.ts | 24 +++++++++++++++ .../web/src/services/api/endpoints/images.ts | 30 ++----------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index da14203fd4..3a681bf431 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -11,6 +11,7 @@ import { omit } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; +import { getCategories, getListImagesUrl } from 'services/api/util'; const log = logger('gallery'); /** @@ -59,6 +60,29 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis return; } + if (imageUploadedClientSide.match(action)) { + const categories = getCategories(imageDTO); + const boardId = imageDTO.board_id ?? 'none'; + dispatch( + imagesApi.util.invalidateTags([ + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: boardId, + categories, + }), + }, + { + type: 'Board', + id: boardId, + }, + { + type: 'BoardImagesTotal', + id: boardId, + }, + ]) + ); + } const state = getState(); log.debug({ imageDTO }, 'Image uploaded'); diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index adefdb62da..8860ebf3f1 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -318,37 +318,11 @@ export const imagesApi = api.injectEndpoints({ }, }), createImageUploadEntry: build.mutation({ - query: ({ width, height }) => ({ + query: ({ width, height, board_id }) => ({ url: buildImagesUrl(), method: 'POST', - body: { width, height }, + body: { width, height, board_id }, }), - invalidatesTags: (result) => { - if (!result) { - // Don't add it to anything - return []; - } - const categories = getCategories(result.image_dto); - const boardId = result.image_dto.board_id ?? 'none'; - - return [ - { - type: 'ImageList', - id: getListImagesUrl({ - board_id: boardId, - categories, - }), - }, - { - type: 'Board', - id: boardId, - }, - { - type: 'BoardImagesTotal', - id: boardId, - }, - ]; - }, }), deleteBoard: build.mutation({ query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }), From 542b1828999512438d02dcbc81cf99ed0014ec93 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:35:06 +1000 Subject: [PATCH 04/21] ci: use uv for python-checks --- .github/workflows/python-checks.yml | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index d81ace9e27..bcb1b0c1ef 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -34,6 +34,9 @@ on: jobs: python-checks: + env: + # uv requires a venv by default - but for this, we can simply use the system python + UV_SYSTEM_PYTHON: 1 runs-on: ubuntu-latest timeout-minutes: 5 # expected run time: <1 min steps: @@ -57,25 +60,23 @@ jobs: - '!invokeai/frontend/web/**' - 'tests/**' - - name: setup python + - name: setup uv if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - uses: actions/setup-python@v5 + uses: astral-sh/setup-uv@v5 with: - python-version: '3.10' - cache: pip - cache-dependency-path: pyproject.toml + version: '0.6.4' + enable-cache: true - - name: install ruff + - name: install python if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: pip install ruff==0.9.9 - shell: bash + run: uv python install - name: ruff check if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: ruff check --output-format=github . + run: uv tool run ruff check --output-format=github . shell: bash - name: ruff format if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: ruff format --check . + run: uv tool run ruff format --check . shell: bash From c0f88a083e7831646f0fae3646299b965d038d8e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:38:03 +1000 Subject: [PATCH 05/21] ci: use uv for python-tests --- .github/workflows/python-tests.yml | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 27411f2f20..bcbbb061fd 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -70,6 +70,9 @@ jobs: timeout-minutes: 15 # expected run time: 2-6 min, depending on platform env: PIP_USE_PEP517: '1' + # uv requires a venv by default - but for this, we can simply use the system python + UV_SYSTEM_PYTHON: 1 + steps: - name: checkout # https://github.com/nschloe/action-cached-lfs-checkout @@ -92,21 +95,24 @@ jobs: - '!invokeai/frontend/web/**' - 'tests/**' - - name: setup python + - name: setup uv if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - uses: actions/setup-python@v5 + uses: astral-sh/setup-uv@v5 with: + version: '0.6.4' + enable-cache: true python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: pyproject.toml + + - name: install python + if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} + run: uv python install - name: install dependencies if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} env: - PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }} - run: > - pip3 install --editable=".[test]" + UV_INDEX: ${{ matrix.extra-index-url }} + run: uv pip install --editable ".[test]" - name: run pytest if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: pytest + run: uv run pytest -v From 403f795c5e8993859a26a5208e79798280611e0e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:38:46 +1000 Subject: [PATCH 06/21] ci: remove linux-cuda-11_7 & linux-rocm-5_2 from test matrix We only have CPU runners, so these tests are not doing anything useful. --- .github/workflows/python-tests.yml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index bcbbb061fd..e48a8154ca 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -42,19 +42,10 @@ jobs: - '3.10' - '3.11' platform: - - linux-cuda-11_7 - - linux-rocm-5_2 - linux-cpu - macos-default - windows-cpu include: - - platform: linux-cuda-11_7 - os: ubuntu-22.04 - github-env: $GITHUB_ENV - - platform: linux-rocm-5_2 - os: ubuntu-22.04 - extra-index-url: 'https://download.pytorch.org/whl/rocm5.2' - github-env: $GITHUB_ENV - platform: linux-cpu os: ubuntu-22.04 extra-index-url: 'https://download.pytorch.org/whl/cpu' From 96c0393fe743f3bd6d8668a63c31e0d8879ee688 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:54:32 +1000 Subject: [PATCH 07/21] ci: bump ruff to 0.11.2 Need to bump both CI and pyproject.toml at the same time --- .github/workflows/python-checks.yml | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index bcb1b0c1ef..dd706df1f4 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -73,10 +73,10 @@ jobs: - name: ruff check if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: uv tool run ruff check --output-format=github . + run: uv tool run ruff@0.11.2 check --output-format=github . shell: bash - name: ruff format if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: uv tool run ruff format --check . + run: uv tool run ruff@0.11.2 format --check . shell: bash diff --git a/pyproject.toml b/pyproject.toml index 1eaabbdfed..6f9a61b88b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,7 +117,7 @@ dependencies = [ ] "dev" = ["jurigged", "pudb", "snakeviz", "gprof2dot"] "test" = [ - "ruff~=0.9.9", + "ruff~=0.11.2", "ruff-lsp~=0.0.62", "mypy", "pre-commit", From 7acaa86bdf8e309401c2a0a97b6436dbc3be976c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:55:55 +1000 Subject: [PATCH 08/21] ci: get ci working with uv instead of pip Lots of squashed experimentation heh: ci: manually specify python version in tests ci: whoops typo in ruff cmds ci: specify python versions for uv python install ci: install python verbosely ci: try forcing python preference? ci: try forcing python preference a different way? ci: try in a venv? ci: it works, but try without venv ci: oh maybe we need --preview? ci: poking it with a stick ci: it works, add summary to pytest output ci: fix pytest output experiment: simulate test failure Revert "experiment: simulate test failure" This reverts commit b99ca512f6e61a2a04a1c0636d44018c11019954. ci: just use default pytest output cI: attempt again to use uv to install python cI: attempt again again to use uv to install python Revert "cI: attempt again again to use uv to install python" This reverts commit 3cba861c90738081caeeb3eca97b60656ab63929. Revert "cI: attempt again to use uv to install python" This reverts commit b30f2277041dc999ed514f6c594c6d6a78f5c810. --- .github/workflows/python-checks.yml | 4 ---- .github/workflows/python-tests.yml | 9 +++++---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index dd706df1f4..a4d740c2a2 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -67,10 +67,6 @@ jobs: version: '0.6.4' enable-cache: true - - name: install python - if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: uv python install - - name: ruff check if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} run: uv tool run ruff@0.11.2 check --output-format=github . diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index e48a8154ca..a8b1c9662a 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -61,7 +61,6 @@ jobs: timeout-minutes: 15 # expected run time: 2-6 min, depending on platform env: PIP_USE_PEP517: '1' - # uv requires a venv by default - but for this, we can simply use the system python UV_SYSTEM_PYTHON: 1 steps: @@ -94,9 +93,11 @@ jobs: enable-cache: true python-version: ${{ matrix.python-version }} - - name: install python + - name: setup python if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: uv python install + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - name: install dependencies if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} @@ -106,4 +107,4 @@ jobs: - name: run pytest if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} - run: uv run pytest -v + run: pytest From 168e5eeff08418e5bff37bc9afd8418d07736d70 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 6 Mar 2025 09:12:40 +1000 Subject: [PATCH 09/21] ci: use uv in typegen-checks ci: use uv in typegen-checks to generate types experiment: simulate typegen-checks failure Revert "experiment: simulate typegen-checks failure" This reverts commit f53c6876fe8311de236d974194abce93ed84930c. --- .github/workflows/typegen-checks.yml | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/.github/workflows/typegen-checks.yml b/.github/workflows/typegen-checks.yml index 9ac8b81f77..d1360d7bc5 100644 --- a/.github/workflows/typegen-checks.yml +++ b/.github/workflows/typegen-checks.yml @@ -54,17 +54,25 @@ jobs: - 'pyproject.toml' - 'invokeai/**' + - name: setup uv + if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }} + uses: astral-sh/setup-uv@v5 + with: + version: '0.6.4' + enable-cache: true + python-version: '3.11' + - name: setup python if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }} uses: actions/setup-python@v5 with: - python-version: '3.10' - cache: pip - cache-dependency-path: pyproject.toml + python-version: '3.11' - - name: install python dependencies + - name: install dependencies if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }} - run: pip3 install --use-pep517 --editable="." + env: + UV_INDEX: ${{ matrix.extra-index-url }} + run: uv pip install --editable . - name: install frontend dependencies if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }} @@ -77,7 +85,7 @@ jobs: - name: generate schema if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }} - run: make frontend-typegen + run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen shell: bash - name: compare files From ed9b30efdad1496c86e8e922cc31ec0e61488eda Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 29 Mar 2025 07:26:20 +1000 Subject: [PATCH 10/21] ci: bump uv to 0.6.10 --- .github/workflows/python-checks.yml | 2 +- .github/workflows/python-tests.yml | 2 +- .github/workflows/typegen-checks.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index a4d740c2a2..3fbca99e51 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -64,7 +64,7 @@ jobs: if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} uses: astral-sh/setup-uv@v5 with: - version: '0.6.4' + version: '0.6.10' enable-cache: true - name: ruff check diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index a8b1c9662a..388d33d5a4 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -89,7 +89,7 @@ jobs: if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }} uses: astral-sh/setup-uv@v5 with: - version: '0.6.4' + version: '0.6.10' enable-cache: true python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/typegen-checks.yml b/.github/workflows/typegen-checks.yml index d1360d7bc5..d706559dbc 100644 --- a/.github/workflows/typegen-checks.yml +++ b/.github/workflows/typegen-checks.yml @@ -58,7 +58,7 @@ jobs: if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }} uses: astral-sh/setup-uv@v5 with: - version: '0.6.4' + version: '0.6.10' enable-cache: true python-version: '3.11' From b0fdc8ae1c5cbef06e59915ac0367454ead47dfc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 29 Mar 2025 07:27:59 +1000 Subject: [PATCH 11/21] ci: bump linux-cpu test runner to ubuntu 24.04 --- .github/workflows/python-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 388d33d5a4..ff167c1685 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -47,7 +47,7 @@ jobs: - windows-cpu include: - platform: linux-cpu - os: ubuntu-22.04 + os: ubuntu-24.04 extra-index-url: 'https://download.pytorch.org/whl/cpu' github-env: $GITHUB_ENV - platform: macos-default From 47cb61cd62df102b18680288aa19d1e65d3a5346 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 29 Mar 2025 07:29:48 +1000 Subject: [PATCH 12/21] ci: remove python 3.10 from test matrix --- .github/workflows/python-tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index ff167c1685..a4a7e38bc2 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -39,7 +39,6 @@ jobs: strategy: matrix: python-version: - - '3.10' - '3.11' platform: - linux-cpu From f6d770eac9752c830c23fbd5688a2810c4253381 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 29 Mar 2025 07:29:58 +1000 Subject: [PATCH 13/21] ci: add python 3.12 to test matrix --- .github/workflows/python-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index a4a7e38bc2..6d7e942e56 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -40,6 +40,7 @@ jobs: matrix: python-version: - '3.11' + - '3.12' platform: - linux-cpu - macos-default From aaa6211625d1bab803d35485751eac332dd08739 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 29 Mar 2025 07:46:44 +1000 Subject: [PATCH 14/21] chore(backend): ruff C420 --- tests/backend/flux/controlnet/test_state_dict_utils.py | 4 ++-- tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/backend/flux/controlnet/test_state_dict_utils.py b/tests/backend/flux/controlnet/test_state_dict_utils.py index 54868a4af7..2688dd3beb 100644 --- a/tests/backend/flux/controlnet/test_state_dict_utils.py +++ b/tests/backend/flux/controlnet/test_state_dict_utils.py @@ -24,7 +24,7 @@ from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs ], ) def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool): - sd = {k: None for k in sd_shapes} + sd = dict.fromkeys(sd_shapes) assert is_state_dict_xlabs_controlnet(sd) == expected @@ -37,7 +37,7 @@ def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expecte ], ) def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool): - sd = {k: None for k in sd_keys} + sd = dict.fromkeys(sd_keys) assert is_state_dict_instantx_controlnet(sd) == expected diff --git a/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py b/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py index 6010bab652..359658eb43 100644 --- a/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py +++ b/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py @@ -19,7 +19,7 @@ from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_v2_state_dict import xl @pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes]) def test_is_state_dict_xlabs_ip_adapter(sd_shapes: dict[str, list[int]]): # Construct a dummy state_dict. - sd = {k: None for k in sd_shapes} + sd = dict.fromkeys(sd_shapes) assert is_state_dict_xlabs_ip_adapter(sd) From 4109ea53240b31166354d6bd65febcadc2a03af5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 28 Mar 2025 14:27:33 +1000 Subject: [PATCH 15/21] fix(nodes): expanded masks not 100% transparent outside the fade out region The polynomial fit isn't perfect and we end up with alpha values of 1 instead of 0 when applying the mask. This in turn causes issues on canvas where outputs aren't 100% transparent and individual layer bbox calculations are incorrect. --- invokeai/app/invocations/image.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 7fc9219954..3f26f169f8 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -1089,7 +1089,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard): @invocation( - "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0" + "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1" ) class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): """Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard. @@ -1147,8 +1147,21 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): coeffs = numpy.polyfit(x_control, y_control, 3) poly = numpy.poly1d(coeffs) - # Evaluate and clip the smooth mapping - feather = numpy.clip(poly(d_norm), 0, 1) + # Evaluate the polynomial + feather = poly(d_norm) + + # The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0, + # even though the control points indicate that they should be exactly 1.0. This is due to the nature of the + # polynomial fit, which is a best approximation of the control points but not an exact match. + + # When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may + # have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0. + + # Force pixels at or beyond the fade distance to exactly 1.0 + feather = numpy.where(d_norm >= 1.0, 1.0, feather) + + # Clip any other values to ensure they're in the valid range [0,1] + feather = numpy.clip(feather, 0, 1) # Build final image. np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8)) From 96fb5f68818449149da943e05752a4696a98bd71 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 30 Mar 2025 20:10:29 +1000 Subject: [PATCH 16/21] feat(ui): disable denoising strength when selected models flux fill --- .../components/ParamDenoisingStrength.tsx | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx index 4314c47c01..bf4464bd5b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx @@ -14,8 +14,9 @@ import WavyLine from 'common/components/WavyLine'; import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice'; import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors'; import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig'; const selectHasRasterLayersWithContent = createSelector( selectActiveRasterLayerEntities, @@ -26,6 +27,7 @@ export const ParamDenoisingStrength = memo(() => { const img2imgStrength = useAppSelector(selectImg2imgStrength); const dispatch = useAppDispatch(); const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent); + const selectedModelConfig = useSelectedModelConfig(); const onChange = useCallback( (v: number) => { @@ -39,8 +41,24 @@ export const ParamDenoisingStrength = memo(() => { const [invokeBlue300] = useToken('colors', ['invokeBlue.300']); + const isDisabled = useMemo(() => { + if (!hasRasterLayersWithContent) { + // Denoising strength does nothing if there are no raster layers w/ content + return true; + } + if ( + selectedModelConfig?.type === 'main' && + selectedModelConfig?.base === 'flux' && + selectedModelConfig.variant === 'inpaint' + ) { + // Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint' + return true; + } + return false; + }, [hasRasterLayersWithContent, selectedModelConfig]); + return ( - + {`${t('parameters.denoisingStrength')}`} @@ -49,7 +67,7 @@ export const ParamDenoisingStrength = memo(() => { )} - {hasRasterLayersWithContent ? ( + {!isDisabled ? ( <> Date: Mon, 31 Mar 2025 09:10:45 +1000 Subject: [PATCH 17/21] fix(mm): handle FLUX models w/ diff in_channels keys Before FLUX Fill was merged, we didn't do any checks for the model variant. We always returned "normal". To determine if a model is a FLUX Fill model, we need to check the state dict for a specific key. Initially, this logic was too strict and rejected quantized FLUX models. This issue was resolved, but it turns out there is another failure mode - some fine-tunes use a different key. This change further reduces the strictness, handling the alternate key and also falling back to "normal" if we don't see either key. This effectively restores the previous probing behaviour for all FLUX models. Closes #7856 Closes #7859 --- .../backend/flux/flux_state_dict_utils.py | 23 +++++++++++++++++++ .../backend/model_manager/legacy_probe.py | 10 +++++++- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 invokeai/backend/flux/flux_state_dict_utils.py diff --git a/invokeai/backend/flux/flux_state_dict_utils.py b/invokeai/backend/flux/flux_state_dict_utils.py new file mode 100644 index 0000000000..8ffab54c68 --- /dev/null +++ b/invokeai/backend/flux/flux_state_dict_utils.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from invokeai.backend.model_manager.legacy_probe import CkptType + + +def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None: + """Gets the in channels from the state dict.""" + + # "Standard" FLUX models use "img_in.weight", but some community fine tunes use + # "model.diffusion_model.img_in.weight". Known models that use the latter key: + # - https://civitai.com/models/885098?modelVersionId=990775 + # - https://civitai.com/models/1018060?modelVersionId=1596255 + # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133 + + keys = {"img_in.weight", "model.diffusion_model.img_in.weight"} + + for key in keys: + val = state_dict.get(key) + if val is not None: + return val.shape[1] + + return None diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 304fbce346..24a5a9f527 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import ( is_state_dict_instantx_controlnet, is_state_dict_xlabs_controlnet, ) +from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash @@ -564,7 +565,14 @@ class CheckpointProbeBase(ProbeBase): state_dict = self.checkpoint.get("state_dict") or self.checkpoint if base_type == BaseModelType.Flux: - in_channels = state_dict["img_in.weight"].shape[1] + in_channels = get_flux_in_channels_from_state_dict(state_dict) + + if in_channels is None: + # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning. + logger.warning( + f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant." + ) + return ModelVariantType.Normal # FLUX Model variant types are distinguished by input channels: # - Unquantized Dev and Schnell have in_channels=64 From 8b299d0bace4dfae5b6e191d9ba2e067906131e5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 31 Mar 2025 12:28:20 +1100 Subject: [PATCH 18/21] chore: prep for v5.9.1 --- invokeai/version/invokeai_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py index b6a1a593e4..b98cc20695 100644 --- a/invokeai/version/invokeai_version.py +++ b/invokeai/version/invokeai_version.py @@ -1 +1 @@ -__version__ = "5.9.0" +__version__ = "5.9.1" From 7be87c8048d6c18431bbedbb2ee9cd2073e7b984 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Mar 2025 18:38:59 +1000 Subject: [PATCH 19/21] refactor(nodes): simpler logic for baseinvocation typeadapter handling --- invokeai/app/invocations/baseinvocation.py | 47 +++++++++------------- tests/test_config.py | 4 +- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 31ac02ce8e..148e4993be 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -8,6 +8,7 @@ import sys import warnings from abc import ABC, abstractmethod from enum import Enum +from functools import lru_cache from inspect import signature from typing import ( TYPE_CHECKING, @@ -27,7 +28,6 @@ import semver from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from typing_extensions import TypeAliasType from invokeai.app.invocations.fields import ( FieldKind, @@ -101,14 +101,12 @@ class BaseInvocationOutput(BaseModel): """ _output_classes: ClassVar[set[BaseInvocationOutput]] = set() - _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None - _typeadapter_needs_update: ClassVar[bool] = False @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: """Registers an invocation output.""" cls._output_classes.add(output) - cls._typeadapter_needs_update = True + cls.get_typeadapter.cache_clear() @classmethod def get_outputs(cls) -> Iterable[BaseInvocationOutput]: @@ -116,15 +114,15 @@ class BaseInvocationOutput(BaseModel): return cls._output_classes @classmethod + @lru_cache(maxsize=1) def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantc TypeAdapter for the union of all invocation output types.""" - if not cls._typeadapter or cls._typeadapter_needs_update: - AnyInvocationOutput = TypeAliasType( - "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] - ) - cls._typeadapter = TypeAdapter(AnyInvocationOutput) - cls._typeadapter_needs_update = False - return cls._typeadapter + """Gets a pydantic TypeAdapter for the union of all invocation output types. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) @classmethod def get_output_types(cls) -> Iterable[str]: @@ -174,8 +172,6 @@ class BaseInvocation(ABC, BaseModel): """ _invocation_classes: ClassVar[set[BaseInvocation]] = set() - _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None - _typeadapter_needs_update: ClassVar[bool] = False @classmethod def get_type(cls) -> str: @@ -186,25 +182,18 @@ class BaseInvocation(ABC, BaseModel): def register_invocation(cls, invocation: BaseInvocation) -> None: """Registers an invocation.""" cls._invocation_classes.add(invocation) - cls._typeadapter_needs_update = True + cls.get_typeadapter.cache_clear() @classmethod + @lru_cache(maxsize=1) def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantc TypeAdapter for the union of all invocation types.""" - if not cls._typeadapter or cls._typeadapter_needs_update: - AnyInvocation = TypeAliasType( - "AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")] - ) - cls._typeadapter = TypeAdapter(AnyInvocation) - cls._typeadapter_needs_update = False - return cls._typeadapter + """Gets a pydantic TypeAdapter for the union of all invocation types. - @classmethod - def invalidate_typeadapter(cls) -> None: - """Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or - denylist is changed, this should be called to ensure the typeadapter is updated and validation respects - the updated allowlist and denylist.""" - cls._typeadapter_needs_update = True + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]) @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: diff --git a/tests/test_config.py b/tests/test_config.py index 220d6f257a..a4e7b0038c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir): # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for # subsequent graph validations - BaseInvocation.invalidate_typeadapter() + BaseInvocation.get_typeadapter.cache_clear() # confirm graph validation fails when using denied node Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir): # Reset the config so that it doesn't affect other tests get_config.cache_clear() - BaseInvocation.invalidate_typeadapter() + BaseInvocation.get_typeadapter.cache_clear() From 6155f9ff9ec91c7676365a1879cdeebbbba91d4e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 11 Mar 2025 10:02:22 +1000 Subject: [PATCH 20/21] feat(nodes): move invocation/output registration to separate class --- invokeai/app/invocations/baseinvocation.py | 172 ++++++++++----------- invokeai/app/services/shared/graph.py | 9 +- invokeai/app/util/custom_openapi.py | 9 +- tests/test_config.py | 8 +- 4 files changed, 101 insertions(+), 97 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 148e4993be..aa1dbe3af4 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -100,35 +100,6 @@ class BaseInvocationOutput(BaseModel): All invocation outputs must use the `@invocation_output` decorator to provide their unique type. """ - _output_classes: ClassVar[set[BaseInvocationOutput]] = set() - - @classmethod - def register_output(cls, output: BaseInvocationOutput) -> None: - """Registers an invocation output.""" - cls._output_classes.add(output) - cls.get_typeadapter.cache_clear() - - @classmethod - def get_outputs(cls) -> Iterable[BaseInvocationOutput]: - """Gets all invocation outputs.""" - return cls._output_classes - - @classmethod - @lru_cache(maxsize=1) - def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantic TypeAdapter for the union of all invocation output types. - - This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or - denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects - the updated allowlist and denylist. - """ - return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) - - @classmethod - def get_output_types(cls) -> Iterable[str]: - """Gets all invocation output types.""" - return (i.get_type() for i in BaseInvocationOutput.get_outputs()) - @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None: """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" @@ -171,67 +142,16 @@ class BaseInvocation(ABC, BaseModel): All invocations must use the `@invocation` decorator to provide their unique type. """ - _invocation_classes: ClassVar[set[BaseInvocation]] = set() - @classmethod def get_type(cls) -> str: """Gets the invocation's type, as provided by the `@invocation` decorator.""" return cls.model_fields["type"].default - @classmethod - def register_invocation(cls, invocation: BaseInvocation) -> None: - """Registers an invocation.""" - cls._invocation_classes.add(invocation) - cls.get_typeadapter.cache_clear() - - @classmethod - @lru_cache(maxsize=1) - def get_typeadapter(cls) -> TypeAdapter[Any]: - """Gets a pydantic TypeAdapter for the union of all invocation types. - - This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or - denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects - the updated allowlist and denylist. - """ - return TypeAdapter(Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]) - - @classmethod - def get_invocations(cls) -> Iterable[BaseInvocation]: - """Gets all invocations, respecting the allowlist and denylist.""" - app_config = get_config() - allowed_invocations: set[BaseInvocation] = set() - for sc in cls._invocation_classes: - invocation_type = sc.get_type() - is_in_allowlist = ( - invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True - ) - is_in_denylist = ( - invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False - ) - if is_in_allowlist and not is_in_denylist: - allowed_invocations.add(sc) - return allowed_invocations - - @classmethod - def get_invocations_map(cls) -> dict[str, BaseInvocation]: - """Gets a map of all invocation types to their invocation classes.""" - return {i.get_type(): i for i in BaseInvocation.get_invocations()} - - @classmethod - def get_invocation_types(cls) -> Iterable[str]: - """Gets all invocation types.""" - return (i.get_type() for i in BaseInvocation.get_invocations()) - @classmethod def get_output_annotation(cls) -> BaseInvocationOutput: """Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method).""" return signature(cls.invoke).return_annotation - @classmethod - def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None: - """Gets the invocation class for a given invocation type.""" - return cls.get_invocations_map().get(invocation_type) - @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" @@ -329,6 +249,87 @@ class BaseInvocation(ABC, BaseModel): TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation) +class InvocationRegistry: + _invocation_classes: ClassVar[set[type[BaseInvocation]]] = set() + _output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set() + + @classmethod + def register_invocation(cls, invocation: type[BaseInvocation]) -> None: + """Registers an invocation.""" + cls._invocation_classes.add(invocation) + cls.get_invocation_typeadapter.cache_clear() + + @classmethod + @lru_cache(maxsize=1) + def get_invocation_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantic TypeAdapter for the union of all invocation types. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")]) + + @classmethod + def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]: + """Gets all invocations, respecting the allowlist and denylist.""" + app_config = get_config() + allowed_invocations: set[type[BaseInvocation]] = set() + for sc in cls._invocation_classes: + invocation_type = sc.get_type() + is_in_allowlist = ( + invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True + ) + is_in_denylist = ( + invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False + ) + if is_in_allowlist and not is_in_denylist: + allowed_invocations.add(sc) + return allowed_invocations + + @classmethod + def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]: + """Gets a map of all invocation types to their invocation classes.""" + return {i.get_type(): i for i in cls.get_invocation_classes()} + + @classmethod + def get_invocation_types(cls) -> Iterable[str]: + """Gets all invocation types.""" + return (i.get_type() for i in cls.get_invocation_classes()) + + @classmethod + def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None: + """Gets the invocation class for a given invocation type.""" + return cls.get_invocations_map().get(invocation_type) + + @classmethod + def register_output(cls, output: "type[TBaseInvocationOutput]") -> None: + """Registers an invocation output.""" + cls._output_classes.add(output) + cls.get_output_typeadapter.cache_clear() + + @classmethod + def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]: + """Gets all invocation outputs.""" + return cls._output_classes + + @classmethod + @lru_cache(maxsize=1) + def get_output_typeadapter(cls) -> TypeAdapter[Any]: + """Gets a pydantic TypeAdapter for the union of all invocation output types. + + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or + denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects + the updated allowlist and denylist. + """ + return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) + + @classmethod + def get_output_types(cls) -> Iterable[str]: + """Gets all invocation output types.""" + return (i.get_type() for i in cls.get_output_classes()) + + RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = { "id", "is_intermediate", @@ -442,8 +443,8 @@ def invocation( node_pack = cls.__module__.split(".")[0] # Handle the case where an existing node is being clobbered by the one we are registering - if invocation_type in BaseInvocation.get_invocation_types(): - clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type) + if invocation_type in InvocationRegistry.get_invocation_types(): + clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type) # This should always be true - we just checked if the invocation type was in the set assert clobbered_invocation is not None @@ -528,8 +529,7 @@ def invocation( ) cls.__doc__ = docstring - # TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic - BaseInvocation.register_invocation(cls) # type: ignore + InvocationRegistry.register_invocation(cls) return cls @@ -554,7 +554,7 @@ def invocation_output( if re.compile(r"^\S+$").match(output_type) is None: raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"') - if output_type in BaseInvocationOutput.get_output_types(): + if output_type in InvocationRegistry.get_output_types(): raise ValueError(f'Invocation type "{output_type}" already exists') validate_fields(cls.model_fields, output_type) @@ -575,7 +575,7 @@ def invocation_output( ) cls.__doc__ = docstring - BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly? + InvocationRegistry.register_output(cls) return cls diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 2d425a7515..fef99753ee 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403 from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + InvocationRegistry, invocation, invocation_output, ) @@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: def validate_invocation(v: Any) -> "AnyInvocation": - return BaseInvocation.get_typeadapter().validate_python(v) + return InvocationRegistry.get_invocation_typeadapter().validate_python(v) return core_schema.no_info_plain_validator_function(validate_invocation) @@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation): # Nodes are too powerful, we have to make our own OpenAPI schema manually # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually oneOf: list[dict[str, str]] = [] - names = [i.__name__ for i in BaseInvocation.get_invocations()] + names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()] for name in sorted(names): oneOf.append({"$ref": f"#/components/schemas/{name}"}) return {"oneOf": oneOf} @@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): def validate_invocation_output(v: Any) -> "AnyInvocationOutput": - return BaseInvocationOutput.get_typeadapter().validate_python(v) + return InvocationRegistry.get_output_typeadapter().validate_python(v) return core_schema.no_info_plain_validator_function(validate_invocation_output) @@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput): # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually oneOf: list[dict[str, str]] = [] - names = [i.__name__ for i in BaseInvocationOutput.get_outputs()] + names = [i.__name__ for i in InvocationRegistry.get_output_classes()] for name in sorted(names): oneOf.append({"$ref": f"#/components/schemas/{name}"}) return {"oneOf": oneOf} diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py index e52028d772..85e17e21c6 100644 --- a/invokeai/app/util/custom_openapi.py +++ b/invokeai/app/util/custom_openapi.py @@ -4,7 +4,10 @@ from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from pydantic.json_schema import models_json_schema -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase +from invokeai.app.invocations.baseinvocation import ( + InvocationRegistry, + UIConfigBase, +) from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.events.events_common import EventBase @@ -56,14 +59,14 @@ def get_openapi_func( invocation_output_map_required: list[str] = [] # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly. - for output in BaseInvocationOutput.get_outputs(): + for output in InvocationRegistry.get_output_classes(): json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") move_defs_to_top_level(openapi_schema, json_schema) openapi_schema["components"]["schemas"][output.__name__] = json_schema # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output # property, so we'll just do it all manually. - for invocation in BaseInvocation.get_invocations(): + for invocation in InvocationRegistry.get_invocation_classes(): json_schema = invocation.model_json_schema( mode="serialization", ref_template="#/components/schemas/{model}" ) diff --git a/tests/test_config.py b/tests/test_config.py index a4e7b0038c..087dc89db5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ from typing import Any import pytest from pydantic import ValidationError -from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.invocations.baseinvocation import InvocationRegistry from invokeai.app.services.config.config_default import ( DefaultInvokeAIAppConfig, InvokeAIAppConfig, @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir): # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for # subsequent graph validations - BaseInvocation.get_typeadapter.cache_clear() + InvocationRegistry.get_invocation_typeadapter.cache_clear() # confirm graph validation fails when using denied node Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) @@ -284,7 +284,7 @@ def test_deny_nodes(patch_rootdir): Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}}) # confirm invocations union will not have denied nodes - all_invocations = BaseInvocation.get_invocations() + all_invocations = InvocationRegistry.get_invocation_classes() has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1 has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1 @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir): # Reset the config so that it doesn't affect other tests get_config.cache_clear() - BaseInvocation.get_typeadapter.cache_clear() + InvocationRegistry.get_invocation_typeadapter.cache_clear() From 595133463e50cc3fd759aaf63cb9916c497c4a63 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 24 Mar 2025 09:40:33 +1000 Subject: [PATCH 21/21] feat(nodes): add methods to invalidate invocation typeadapters --- invokeai/app/invocations/baseinvocation.py | 22 ++++++++++++++++++++-- tests/test_config.py | 4 ++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index aa1dbe3af4..b2f3c3a9f2 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -257,19 +257,28 @@ class InvocationRegistry: def register_invocation(cls, invocation: type[BaseInvocation]) -> None: """Registers an invocation.""" cls._invocation_classes.add(invocation) - cls.get_invocation_typeadapter.cache_clear() + cls.invalidate_invocation_typeadapter() @classmethod @lru_cache(maxsize=1) def get_invocation_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantic TypeAdapter for the union of all invocation types. + This is used to parse serialized invocations into the correct invocation class. + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects the updated allowlist and denylist. + + @see https://docs.pydantic.dev/latest/concepts/type_adapter/ """ return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")]) + @classmethod + def invalidate_invocation_typeadapter(cls) -> None: + """Invalidates the cached invocation type adapter.""" + cls.get_invocation_typeadapter.cache_clear() + @classmethod def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]: """Gets all invocations, respecting the allowlist and denylist.""" @@ -306,7 +315,7 @@ class InvocationRegistry: def register_output(cls, output: "type[TBaseInvocationOutput]") -> None: """Registers an invocation output.""" cls._output_classes.add(output) - cls.get_output_typeadapter.cache_clear() + cls.invalidate_output_typeadapter() @classmethod def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]: @@ -318,12 +327,21 @@ class InvocationRegistry: def get_output_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantic TypeAdapter for the union of all invocation output types. + This is used to parse serialized invocation outputs into the correct invocation output class. + This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects the updated allowlist and denylist. + + @see https://docs.pydantic.dev/latest/concepts/type_adapter/ """ return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]) + @classmethod + def invalidate_output_typeadapter(cls) -> None: + """Invalidates the cached invocation output type adapter.""" + cls.get_output_typeadapter.cache_clear() + @classmethod def get_output_types(cls) -> Iterable[str]: """Gets all invocation output types.""" diff --git a/tests/test_config.py b/tests/test_config.py index 087dc89db5..610162c075 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir): # We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for # subsequent graph validations - InvocationRegistry.get_invocation_typeadapter.cache_clear() + InvocationRegistry.invalidate_invocation_typeadapter() # confirm graph validation fails when using denied node Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}}) @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir): # Reset the config so that it doesn't affect other tests get_config.cache_clear() - InvocationRegistry.get_invocation_typeadapter.cache_clear() + InvocationRegistry.invalidate_invocation_typeadapter()