feat: more batch types (wip)

This commit is contained in:
psychedelicious 2024-11-27 14:01:07 +11:00
parent e44458609f
commit 64e5c6add7
No known key found for this signature in database
11 changed files with 478 additions and 22 deletions

View File

@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
BoardField,
BoundingBoxField,
ColorField,
ConditioningField,
@ -544,7 +545,7 @@ class BoundingBoxInvocation(BaseInvocation):
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
@ -559,3 +560,43 @@ class ImageBatchInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")
@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(min_length=1, description="The strings to batch over", input=Input.Direct)
def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")
def invoke(self, context: InvocationContext) -> StringOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")
@invocation(
"board_batch",
title="Board Batch",
tags=["primitives", "image", "board", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class BoardBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch. The images are populated from the selected board."""
board_to_batch: BoardField = InputField(description="The board to batch over", title="Board", input=Input.Direct)
def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")

View File

@ -30,9 +30,17 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import {
fieldImageCollectionValueChanged,
fieldImageValueChanged,
fieldStringCollectionValueChanged,
} from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import {
type FieldIdentifier,
isImageFieldCollectionInputInstance,
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { uniqBy } from 'lodash-es';
@ -124,6 +132,76 @@ export const removeImageFromNodeImageFieldCollectionAction = (arg: {
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const addStringToNodeStringFieldCollectionAction = (arg: {
value: string;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { value, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection');
return;
}
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
fieldValue.push(value);
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
};
export const removeStringFromNodeStringFieldCollectionAction = (arg: {
index: number;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { index, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to remove string to a non-string field collection');
return;
}
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
fieldValue.splice(index, 1);
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
};
export const changeStringOnNodeStringFieldCollectionAction = (arg: {
index: number;
value: string;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { index, value, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection');
return;
}
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
fieldValue.splice(index, 1, value);
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
};
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(imageToCompareChanged(imageDTO));

View File

@ -1,5 +1,6 @@
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@ -49,6 +50,8 @@ import {
isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
@ -94,6 +97,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@ -1,5 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
import type { BoardFieldInputInstance, BoardFieldInputTemplate } from 'features/nodes/types/field';
@ -17,17 +18,18 @@ const BoardFieldInputComponent = (props: FieldComponentProps<BoardFieldInputInst
{ include_archived: true },
{
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: 'None',
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
if (!data) {
return {
options: EMPTY_ARRAY,
hasBoards: false,
};
}
const options: ComboboxOption[] = data.map(({ board_id, board_name, image_count }) => ({
label: `${board_name} (${image_count})`,
value: board_id,
}));
return {
options,
hasBoards: options.length > 1,
@ -45,14 +47,14 @@ const BoardFieldInputComponent = (props: FieldComponentProps<BoardFieldInputInst
fieldBoardValueChanged({
nodeId,
fieldName: field.name,
value: v.value !== 'none' ? { board_id: v.value } : undefined,
value: { board_id: v.value },
})
);
},
[dispatch, field.name, nodeId]
);
const value = useMemo(() => options.find((o) => o.value === field.value?.board_id), [options, field.value]);
const value = useMemo(() => options.find((o) => o.value === field.value?.board_id) ?? null, [options, field.value]);
const noOptionsMessage = useCallback(() => t('boards.noMatching'), [t]);

View File

@ -90,7 +90,7 @@ export const ImageFieldCollectionInputComponent = memo(
isError={isInvalid}
onUpload={onUpload}
fontSize={24}
variant="outline"
variant="ghost"
/>
)}
{field.value && field.value.length > 0 && (

View File

@ -0,0 +1,168 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Grid, GridItem, IconButton, Textarea } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import {
addStringToNodeStringFieldCollectionAction,
changeStringOnNodeStringFieldCollectionAction,
removeStringFromNodeStringFieldCollectionAction,
} from 'features/imageActions/actions';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import type {
StringFieldCollectionInputInstance,
StringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const StringFieldCollectionInputComponent = memo(
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
const { t } = useTranslation();
const { nodeId, field } = props;
const store = useAppStore();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const onRemoveString = useCallback(
(index: number) => {
removeStringFromNodeStringFieldCollectionAction({
index,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
},
[field.name, nodeId, store.dispatch, store.getState]
);
const onChangeString = useCallback(
(index: number, value: string) => {
changeStringOnNodeStringFieldCollectionAction({
index,
value,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
},
[field.name, nodeId, store.dispatch, store.getState]
);
const onAddString = useCallback(() => {
addStringToNodeStringFieldCollectionAction({
value: '',
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
}, [field.name, nodeId, store]);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="full"
maxH={64}
alignItems="stretch"
justifyContent="center"
>
{(!field.value || field.value.length === 0) && (
<Box w="full" sx={sx} data-error={isInvalid} borderRadius="base">
<IconButton
w="full"
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
</Box>
)}
{field.value && field.value.length > 0 && (
<Box w="full" h="auto" p={1} sx={sx} data-error={isInvalid} borderRadius="base">
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(1, 1fr)" gap={1}>
<IconButton
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<StringListItemContent
value={value}
index={index}
onRemoveString={onRemoveString}
onChangeString={onChangeString}
/>
</GridItem>
))}
</Grid>
</OverlayScrollbarsComponent>
</Box>
)}
</Flex>
);
}
);
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
type StringListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<Flex alignItems="center" gap={1}>
<Textarea size="xs" resize="none" value={value} onChange={onChange} />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.remove')}
tooltip={t('common.remove')}
/>
</Flex>
);
});
StringListItemContent.displayName = 'StringListItemContent';

View File

@ -3,7 +3,12 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import {
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
@ -46,6 +51,18 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
}
// Else special handling for individual field types
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
// String collections may have min or max item counts
if (template.minItems !== undefined && field.value.length < template.minItems) {
return true;
}
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
return true;
}
}
// Field looks OK
return false;
});

View File

@ -27,6 +27,7 @@ import type {
SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
StringFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
@ -54,6 +55,7 @@ import {
zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
@ -309,6 +311,9 @@ export const nodesSlice = createSlice({
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
},
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
fieldValueReducer(state, action, zStringFieldCollectionValue);
},
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
},
@ -433,6 +438,7 @@ export const {
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
@ -543,6 +549,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,

View File

@ -86,6 +86,15 @@ const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
});
const zStringCollectionFieldType = z.object({
name: z.literal('StringField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isStringCollectionFieldType = (
fieldType: FieldType
): fieldType is z.infer<typeof zStringCollectionFieldType> => zStringCollectionFieldType.safeParse(fieldType).success;
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
originalType: zStatelessFieldType.optional(),
@ -310,6 +319,52 @@ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType,
});
// #region StringField Collection
export const zStringFieldCollectionValue = z.array(zStringFieldValue).optional();
const zStringFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zStringFieldCollectionValue,
});
const zStringFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zStringCollectionFieldType,
originalType: zFieldType.optional(),
default: zStringFieldCollectionValue,
maxLength: z.number().int().gte(0).optional(),
minLength: z.number().int().gte(0).optional(),
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
})
.refine(
(val) => {
if (val.maxLength !== undefined && val.minLength !== undefined) {
return val.maxLength >= val.minLength;
}
return true;
},
{ message: 'maxLength must be greater than or equal to minLength' }
)
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zStringFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringCollectionFieldType,
});
export type StringFieldCollectionValue = z.infer<typeof zStringFieldCollectionValue>;
export type StringFieldCollectionInputInstance = z.infer<typeof zStringFieldCollectionInputInstance>;
export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldCollectionInputTemplate>;
export const isStringFieldCollectionInputInstance = (val: unknown): val is StringFieldCollectionInputInstance =>
zStringFieldCollectionInputInstance.safeParse(val).success;
export const isStringFieldCollectionInputTemplate = (val: unknown): val is StringFieldCollectionInputTemplate =>
zStringFieldCollectionInputTemplate.safeParse(val).success;
// #endregion
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
@ -409,7 +464,7 @@ const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase
}
return true;
},
{ message: 'maxLength must be greater than or equal to minLength' }
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
@ -937,6 +992,7 @@ export const zStatefulFieldValue = z.union([
zIntegerFieldValue,
zFloatFieldValue,
zStringFieldValue,
zStringFieldCollectionValue,
zBooleanFieldValue,
zEnumFieldValue,
zImageFieldValue,
@ -973,6 +1029,7 @@ const zStatefulFieldInputInstance = z.union([
zIntegerFieldInputInstance,
zFloatFieldInputInstance,
zStringFieldInputInstance,
zStringFieldCollectionInputInstance,
zBooleanFieldInputInstance,
zEnumFieldInputInstance,
zImageFieldInputInstance,
@ -1008,6 +1065,7 @@ const zStatefulFieldInputTemplate = z.union([
zIntegerFieldInputTemplate,
zFloatFieldInputTemplate,
zStringFieldInputTemplate,
zStringFieldCollectionInputTemplate,
zBooleanFieldInputTemplate,
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
@ -1046,6 +1104,7 @@ const zStatefulFieldOutputTemplate = z.union([
zIntegerFieldOutputTemplate,
zFloatFieldOutputTemplate,
zStringFieldOutputTemplate,
zStringFieldCollectionOutputTemplate,
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,

View File

@ -27,12 +27,17 @@ import type {
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { isImageCollectionFieldType, isStatefulFieldType } from 'features/nodes/types/field';
import {
isImageCollectionFieldType,
isStatefulFieldType,
isStringCollectionFieldType,
} from 'features/nodes/types/field';
import type { InvocationFieldSchema } from 'features/nodes/types/openapi';
import { isSchemaObject } from 'features/nodes/types/openapi';
import { t } from 'i18next';
@ -132,6 +137,36 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
return template;
};
const buildStringFieldCollectionInputTemplate: FieldInputTemplateBuilder<StringFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StringFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
return template;
};
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
schemaObject,
baseField,
@ -553,12 +588,18 @@ export const buildFieldInputTemplate = (
if (isStatefulFieldType(fieldType)) {
if (isImageCollectionFieldType(fieldType)) {
fieldType;
return buildImageFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isStringCollectionFieldType(fieldType)) {
fieldType;
return buildStringFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else {
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
const template = builder({

View File

@ -11,7 +11,12 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState, Templates } from 'features/nodes/store/types';
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import {
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
@ -123,6 +128,37 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
});
return;
}
} else if (
field.value &&
isStringFieldCollectionInputInstance(field) &&
isStringFieldCollectionInputTemplate(fieldTemplate)
) {
// String collections may have min or max items to validate
// TODO(psyche): generalize this to other collection types
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
return;
}
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooFewItems', {
...baseTKeyOptions,
size: field.value.length,
minItems: fieldTemplate.minItems,
}),
});
return;
}
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooManyItems', {
...baseTKeyOptions,
size: field.value.length,
maxItems: fieldTemplate.maxItems,
}),
});
return;
}
}
});
});