mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
feat: more batch types (wip)
This commit is contained in:
parent
e44458609f
commit
64e5c6add7
@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoardField,
|
||||
BoundingBoxField,
|
||||
ColorField,
|
||||
ConditioningField,
|
||||
@ -544,7 +545,7 @@ class BoundingBoxInvocation(BaseInvocation):
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "internal"],
|
||||
tags=["primitives", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
@ -559,3 +560,43 @@ class ImageBatchInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_batch",
|
||||
title="String Batch",
|
||||
tags=["primitives", "string", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class StringBatchInvocation(BaseInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
|
||||
|
||||
strings: list[str] = InputField(min_length=1, description="The strings to batch over", input=Input.Direct)
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
|
||||
@invocation(
|
||||
"board_batch",
|
||||
title="Board Batch",
|
||||
tags=["primitives", "image", "board", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class BoardBatchInvocation(BaseInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch. The images are populated from the selected board."""
|
||||
|
||||
board_to_batch: BoardField = InputField(description="The board to batch over", title="Board", input=Input.Direct)
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
@ -30,9 +30,17 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
fieldImageCollectionValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import {
|
||||
type FieldIdentifier,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
@ -124,6 +132,76 @@ export const removeImageFromNodeImageFieldCollectionAction = (arg: {
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
};
|
||||
|
||||
export const addStringToNodeStringFieldCollectionAction = (arg: {
|
||||
value: string;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { value, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
fieldValue.push(value);
|
||||
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
|
||||
};
|
||||
|
||||
export const removeStringFromNodeStringFieldCollectionAction = (arg: {
|
||||
index: number;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { index, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to remove string to a non-string field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
fieldValue.splice(index, 1);
|
||||
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
|
||||
};
|
||||
|
||||
export const changeStringOnNodeStringFieldCollectionAction = (arg: {
|
||||
index: number;
|
||||
value: string;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { index, value, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
fieldValue.splice(index, 1, value);
|
||||
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
|
||||
};
|
||||
|
||||
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
|
||||
const { imageDTO, dispatch } = arg;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import {
|
||||
@ -49,6 +50,8 @@ import {
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
@ -94,6 +97,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
|
||||
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { BoardFieldInputInstance, BoardFieldInputTemplate } from 'features/nodes/types/field';
|
||||
@ -17,17 +18,18 @@ const BoardFieldInputComponent = (props: FieldComponentProps<BoardFieldInputInst
|
||||
{ include_archived: true },
|
||||
{
|
||||
selectFromResult: ({ data }) => {
|
||||
const options: ComboboxOption[] = [
|
||||
{
|
||||
label: 'None',
|
||||
value: 'none',
|
||||
},
|
||||
].concat(
|
||||
(data ?? []).map(({ board_id, board_name }) => ({
|
||||
label: board_name,
|
||||
value: board_id,
|
||||
}))
|
||||
);
|
||||
if (!data) {
|
||||
return {
|
||||
options: EMPTY_ARRAY,
|
||||
hasBoards: false,
|
||||
};
|
||||
}
|
||||
|
||||
const options: ComboboxOption[] = data.map(({ board_id, board_name, image_count }) => ({
|
||||
label: `${board_name} (${image_count})`,
|
||||
value: board_id,
|
||||
}));
|
||||
|
||||
return {
|
||||
options,
|
||||
hasBoards: options.length > 1,
|
||||
@ -45,14 +47,14 @@ const BoardFieldInputComponent = (props: FieldComponentProps<BoardFieldInputInst
|
||||
fieldBoardValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v.value !== 'none' ? { board_id: v.value } : undefined,
|
||||
value: { board_id: v.value },
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value?.board_id), [options, field.value]);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value?.board_id) ?? null, [options, field.value]);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('boards.noMatching'), [t]);
|
||||
|
||||
|
@ -90,7 +90,7 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
isError={isInvalid}
|
||||
onUpload={onUpload}
|
||||
fontSize={24}
|
||||
variant="outline"
|
||||
variant="ghost"
|
||||
/>
|
||||
)}
|
||||
{field.value && field.value.length > 0 && (
|
||||
|
@ -0,0 +1,168 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Grid, GridItem, IconButton, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import {
|
||||
addStringToNodeStringFieldCollectionAction,
|
||||
changeStringOnNodeStringFieldCollectionAction,
|
||||
removeStringFromNodeStringFieldCollectionAction,
|
||||
} from 'features/imageActions/actions';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import type {
|
||||
StringFieldCollectionInputInstance,
|
||||
StringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
'&[data-error=true]': {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const StringFieldCollectionInputComponent = memo(
|
||||
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
|
||||
const { t } = useTranslation();
|
||||
const { nodeId, field } = props;
|
||||
const store = useAppStore();
|
||||
|
||||
const isInvalid = useFieldIsInvalid(nodeId, field.name);
|
||||
|
||||
const onRemoveString = useCallback(
|
||||
(index: number) => {
|
||||
removeStringFromNodeStringFieldCollectionAction({
|
||||
index,
|
||||
fieldIdentifier: { nodeId, fieldName: field.name },
|
||||
dispatch: store.dispatch,
|
||||
getState: store.getState,
|
||||
});
|
||||
},
|
||||
[field.name, nodeId, store.dispatch, store.getState]
|
||||
);
|
||||
|
||||
const onChangeString = useCallback(
|
||||
(index: number, value: string) => {
|
||||
changeStringOnNodeStringFieldCollectionAction({
|
||||
index,
|
||||
value,
|
||||
fieldIdentifier: { nodeId, fieldName: field.name },
|
||||
dispatch: store.dispatch,
|
||||
getState: store.getState,
|
||||
});
|
||||
},
|
||||
[field.name, nodeId, store.dispatch, store.getState]
|
||||
);
|
||||
|
||||
const onAddString = useCallback(() => {
|
||||
addStringToNodeStringFieldCollectionAction({
|
||||
value: '',
|
||||
fieldIdentifier: { nodeId, fieldName: field.name },
|
||||
dispatch: store.dispatch,
|
||||
getState: store.getState,
|
||||
});
|
||||
}, [field.name, nodeId, store]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className="nodrag"
|
||||
position="relative"
|
||||
w="full"
|
||||
h="full"
|
||||
maxH={64}
|
||||
alignItems="stretch"
|
||||
justifyContent="center"
|
||||
>
|
||||
{(!field.value || field.value.length === 0) && (
|
||||
<Box w="full" sx={sx} data-error={isInvalid} borderRadius="base">
|
||||
<IconButton
|
||||
w="full"
|
||||
onClick={onAddString}
|
||||
aria-label="Add Item"
|
||||
icon={<PiPlusBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
{field.value && field.value.length > 0 && (
|
||||
<Box w="full" h="auto" p={1} sx={sx} data-error={isInvalid} borderRadius="base">
|
||||
<OverlayScrollbarsComponent
|
||||
className="nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid w="full" h="full" templateColumns="repeat(1, 1fr)" gap={1}>
|
||||
<IconButton
|
||||
onClick={onAddString}
|
||||
aria-label="Add Item"
|
||||
icon={<PiPlusBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
{field.value.map((value, index) => (
|
||||
<GridItem key={index} position="relative" className="nodrag">
|
||||
<StringListItemContent
|
||||
value={value}
|
||||
index={index}
|
||||
onRemoveString={onRemoveString}
|
||||
onChangeString={onChangeString}
|
||||
/>
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
|
||||
|
||||
type StringListItemContentProps = {
|
||||
value: string;
|
||||
index: number;
|
||||
onRemoveString: (index: number) => void;
|
||||
onChangeString: (index: number, value: string) => void;
|
||||
};
|
||||
|
||||
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveString(index);
|
||||
}, [index, onRemoveString]);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
onChangeString(index, e.target.value);
|
||||
},
|
||||
[index, onChangeString]
|
||||
);
|
||||
return (
|
||||
<Flex alignItems="center" gap={1}>
|
||||
<Textarea size="xs" resize="none" value={value} onChange={onChange} />
|
||||
<IconButton
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.remove')}
|
||||
tooltip={t('common.remove')}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
StringListItemContent.displayName = 'StringListItemContent';
|
@ -3,7 +3,12 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import {
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
@ -46,6 +51,18 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Else special handling for individual field types
|
||||
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
|
||||
// String collections may have min or max item counts
|
||||
if (template.minItems !== undefined && field.value.length < template.minItems) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Field looks OK
|
||||
return false;
|
||||
});
|
||||
|
@ -27,6 +27,7 @@ import type {
|
||||
SDXLRefinerModelFieldValue,
|
||||
SpandrelImageToImageModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
StringFieldCollectionValue,
|
||||
StringFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
T5EncoderModelFieldValue,
|
||||
@ -54,6 +55,7 @@ import {
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
zStringFieldCollectionValue,
|
||||
zStringFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
@ -309,6 +311,9 @@ export const nodesSlice = createSlice({
|
||||
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStringFieldValue);
|
||||
},
|
||||
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zStringFieldCollectionValue);
|
||||
},
|
||||
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
|
||||
},
|
||||
@ -433,6 +438,7 @@ export const {
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
@ -543,6 +549,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
|
@ -86,6 +86,15 @@ const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStringCollectionFieldType = z.object({
|
||||
name: z.literal('StringField'),
|
||||
cardinality: z.literal(COLLECTION),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
export const isStringCollectionFieldType = (
|
||||
fieldType: FieldType
|
||||
): fieldType is z.infer<typeof zStringCollectionFieldType> => zStringCollectionFieldType.safeParse(fieldType).success;
|
||||
|
||||
const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@ -310,6 +319,52 @@ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
|
||||
// #region StringField Collection
|
||||
export const zStringFieldCollectionValue = z.array(zStringFieldValue).optional();
|
||||
const zStringFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStringFieldCollectionValue,
|
||||
});
|
||||
const zStringFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
.extend({
|
||||
type: zStringCollectionFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStringFieldCollectionValue,
|
||||
maxLength: z.number().int().gte(0).optional(),
|
||||
minLength: z.number().int().gte(0).optional(),
|
||||
maxItems: z.number().int().gte(0).optional(),
|
||||
minItems: z.number().int().gte(0).optional(),
|
||||
})
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val.maxLength !== undefined && val.minLength !== undefined) {
|
||||
return val.maxLength >= val.minLength;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxLength must be greater than or equal to minLength' }
|
||||
)
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val.maxItems !== undefined && val.minItems !== undefined) {
|
||||
return val.maxItems >= val.minItems;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxItems must be greater than or equal to minItems' }
|
||||
);
|
||||
|
||||
const zStringFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStringCollectionFieldType,
|
||||
});
|
||||
export type StringFieldCollectionValue = z.infer<typeof zStringFieldCollectionValue>;
|
||||
export type StringFieldCollectionInputInstance = z.infer<typeof zStringFieldCollectionInputInstance>;
|
||||
export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldCollectionInputTemplate>;
|
||||
export const isStringFieldCollectionInputInstance = (val: unknown): val is StringFieldCollectionInputInstance =>
|
||||
zStringFieldCollectionInputInstance.safeParse(val).success;
|
||||
export const isStringFieldCollectionInputTemplate = (val: unknown): val is StringFieldCollectionInputTemplate =>
|
||||
zStringFieldCollectionInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
|
||||
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
|
||||
@ -409,7 +464,7 @@ const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxLength must be greater than or equal to minLength' }
|
||||
{ message: 'maxItems must be greater than or equal to minItems' }
|
||||
);
|
||||
|
||||
const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@ -937,6 +992,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zIntegerFieldValue,
|
||||
zFloatFieldValue,
|
||||
zStringFieldValue,
|
||||
zStringFieldCollectionValue,
|
||||
zBooleanFieldValue,
|
||||
zEnumFieldValue,
|
||||
zImageFieldValue,
|
||||
@ -973,6 +1029,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zIntegerFieldInputInstance,
|
||||
zFloatFieldInputInstance,
|
||||
zStringFieldInputInstance,
|
||||
zStringFieldCollectionInputInstance,
|
||||
zBooleanFieldInputInstance,
|
||||
zEnumFieldInputInstance,
|
||||
zImageFieldInputInstance,
|
||||
@ -1008,6 +1065,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zIntegerFieldInputTemplate,
|
||||
zFloatFieldInputTemplate,
|
||||
zStringFieldInputTemplate,
|
||||
zStringFieldCollectionInputTemplate,
|
||||
zBooleanFieldInputTemplate,
|
||||
zEnumFieldInputTemplate,
|
||||
zImageFieldInputTemplate,
|
||||
@ -1046,6 +1104,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zIntegerFieldOutputTemplate,
|
||||
zFloatFieldOutputTemplate,
|
||||
zStringFieldOutputTemplate,
|
||||
zStringFieldCollectionOutputTemplate,
|
||||
zBooleanFieldOutputTemplate,
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
|
@ -27,12 +27,17 @@ import type {
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
T5EncoderModelFieldInputTemplate,
|
||||
VAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isImageCollectionFieldType, isStatefulFieldType } from 'features/nodes/types/field';
|
||||
import {
|
||||
isImageCollectionFieldType,
|
||||
isStatefulFieldType,
|
||||
isStringCollectionFieldType,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { InvocationFieldSchema } from 'features/nodes/types/openapi';
|
||||
import { isSchemaObject } from 'features/nodes/types/openapi';
|
||||
import { t } from 'i18next';
|
||||
@ -132,6 +137,36 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringFieldCollectionInputTemplate: FieldInputTemplateBuilder<StringFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: StringFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -553,12 +588,18 @@ export const buildFieldInputTemplate = (
|
||||
|
||||
if (isStatefulFieldType(fieldType)) {
|
||||
if (isImageCollectionFieldType(fieldType)) {
|
||||
fieldType;
|
||||
return buildImageFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isStringCollectionFieldType(fieldType)) {
|
||||
fieldType;
|
||||
return buildStringFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else {
|
||||
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
|
||||
const template = builder({
|
||||
|
@ -11,7 +11,12 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { NodesState, Templates } from 'features/nodes/store/types';
|
||||
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import {
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
@ -123,6 +128,37 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
});
|
||||
return;
|
||||
}
|
||||
} else if (
|
||||
field.value &&
|
||||
isStringFieldCollectionInputInstance(field) &&
|
||||
isStringFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
// String collections may have min or max items to validate
|
||||
// TODO(psyche): generalize this to other collection types
|
||||
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
|
||||
return;
|
||||
}
|
||||
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.collectionTooFewItems', {
|
||||
...baseTKeyOptions,
|
||||
size: field.value.length,
|
||||
minItems: fieldTemplate.minItems,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.collectionTooManyItems', {
|
||||
...baseTKeyOptions,
|
||||
size: field.value.length,
|
||||
maxItems: fieldTemplate.maxItems,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user