refactor(nodes): model identifiers

- All models are now identified by a key and, optionally, a submodel type via the new `ModelField` model. Previously, a few model types had their own field class while others did not; this inconsistency added complexity without any benefit.
- Update all invocations to use the new format.
- In the node API, models are loaded by key or, as a convenience, by an instance of `ModelField`.
- Add an enriched model schema for metadata. It includes key, hash, name, base, and type.
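
For orientation, this is the shape of the new identifier as defined in `model.py` below, together with the convenience the node API now offers (the keys in the comments are hypothetical):

```python
from typing import Optional

from pydantic import BaseModel, Field

from invokeai.backend.model_manager import SubModelType


class ModelField(BaseModel):
    key: str = Field(description="Key of the model")
    submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)


# Inside an invocation, both forms now resolve to the same loaded model:
#   context.models.load("26d1c8f2")                                   # by bare key
#   context.models.load(ModelField(key="26d1c8f2",
#                                  submodel_type=SubModelType.UNet))  # by ModelField
```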
psychedelicious 2024-03-06 19:37:15 +11:00
parent afd9ae7712
commit 528ac5dd25
15 changed files with 229 additions and 288 deletions

View File

@@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@@ -133,10 +133,10 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
@@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
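
The compel changes above all follow the same shape: fields carry model identifiers, and the node API loads them directly instead of unpacking `model_dump()` kwargs. A minimal sketch of the new `_lora_loader` pattern, assuming the `LoRAField` shape defined in `model.py` later in this commit and the usual `InvocationContext`:

```python
from typing import Iterator, Tuple


def lora_loader(context, clip_field) -> Iterator[Tuple["LoRAModelRaw", float]]:
    # Each entry in clip_field.loras is a LoRAField: a ModelField plus a weight.
    for lora in clip_field.loras:
        lora_info = context.models.load(lora.lora)  # load by ModelField directly
        yield (lora_info.model, lora.weight)
        del lora_info  # drop the reference so the cache can evict the model
```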

View File

@@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[
]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_model: ModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)

View File

@@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
# endregion
class MetadataField(RootModel):
class MetadataField(RootModel[dict[str, Any]]):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema.
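
Parameterizing `RootModel` with `dict[str, Any]` gives the custom root a concrete type, so validation and access are checked. A standalone sketch:

```python
from typing import Any

from pydantic import RootModel


class MetadataField(RootModel[dict[str, Any]]):
    """Pydantic model for metadata with a custom dict root."""


meta = MetadataField.model_validate({"seed": 123, "steps": 30})
assert meta.root["seed"] == 123  # the underlying dict lives on .root
```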

View File

@@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
key: str = Field(description="Key to the CLIP Vision image encoder model")
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = InputField(
ip_adapter_model: ModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
@@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=image_encoder_model,
image_encoder_model=ModelField(key=image_encoder_models[0].key),
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,

View File

@@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
@@ -75,7 +76,7 @@ from .baseinvocation import (
invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
from .model import ModelField, UNetField, VaeField
if choose_torch_device() == torch.device("mps"):
from torch import mps
@@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@@ -244,12 +245,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
scheduler_info: ModelField,
scheduler_name: str,
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -461,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
# control_models.append(control_model)
control_image_field = control_info.image
@@ -523,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
context.models.load(single_ip_adapter.ip_adapter_model)
)
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -538,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
@@ -577,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -731,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(**self.unet.unet.model_dump())
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
@@ -841,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
@@ -1064,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import (
OutputField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from ...version import __version__
@@ -30,10 +31,20 @@ class MetadataItemField(BaseModel):
value: Any = Field(description=FieldDescriptions.metadata_item_value)
class ModelMetadataField(BaseModel):
"""Model Metadata Field"""
key: str
hash: str
name: str
base: BaseModelType
type: ModelType
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(
ip_adapter_model: ModelMetadataField = Field(
description="The IP-Adapter model.",
)
weight: Union[float, list[float]] = Field(
@@ -51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel):
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("metadata_item_output")
class MetadataItemOutput(BaseInvocationOutput):
"""Metadata Item Output"""
@@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The number of skipped CLIP layers",
)
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlField]] = InputField(
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
default=None, description="The ControlNets used for inference"
)
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The name of the initial image",
)
vae: Optional[VAEModelField] = InputField(
vae: Optional[ModelMetadataField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
@@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation):
)
# SDXL Refiner
refiner_model: Optional[MainModelField] = InputField(
refiner_model: Optional[ModelMetadataField] = InputField(
default=None,
description="The SDXL Refiner model used",
)
@@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataOutput(
metadata=MetadataField.model_validate(
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
)
)
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
as_dict["app_version"] = __version__
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
model_config = ConfigDict(extra="allow")
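
The enriched `ModelMetadataField` introduced above captures everything needed to identify a model in metadata. A sketch with hypothetical values (the enum members are assumed from the InvokeAI enums in `invokeai.backend.model_manager.config`):

```python
from invokeai.backend.model_manager.config import BaseModelType, ModelType

mm = ModelMetadataField(
    key="26d1c8f2",         # hypothetical record key
    hash="blake3:9f2c...",  # hypothetical content hash
    name="stable-diffusion-v1-5",
    base=BaseModelType.StableDiffusion1,
    type=ModelType.Main,
)
```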

View File

@@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -16,33 +16,34 @@ from .baseinvocation import (
)
class ModelInfo(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class ModelField(BaseModel):
key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
loras: List[LoRAField] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
vae: ModelField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -74,18 +75,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation(
"main_model_loader",
title="Main Model",
@@ -96,46 +85,24 @@ class LoRAModelField(BaseModel):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
key = self.model.key
# TODO: not found exceptions
if not context.models.exists(key):
raise Exception(f"Unknown model {key}")
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.VAE,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
vae=VaeField(vae=vae),
)
@@ -151,7 +118,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@@ -167,38 +134,33 @@ class LoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@@ -225,7 +187,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@@ -247,51 +209,45 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@@ -299,17 +255,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: VAEModelField = InputField(
vae_model: ModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
@@ -321,7 +271,7 @@ class VaeLoaderInvocation(BaseInvocation):
if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
return VAEOutput(vae=VaeField(vae=self.vae_model))
@invocation_output("seamless_output")

View File

@@ -8,7 +8,7 @@ from .baseinvocation import (
invocation,
invocation_output,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
from .model import ClipField, ModelField, UNetField, VaeField
@invocation_output("sdxl_model_loader_output")
@@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: MainModelField = InputField(
model: ModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VaeField(vae=vae),
)
@@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
model: ModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?
@@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VaeField(vae=vae),
)

View File

@@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterModelField(BaseModel):
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = InputField(
t2i_adapter_model: ModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
import torch
from typing_extensions import Self

View File

@@ -130,6 +130,17 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param hash: Hash of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default

View File

@@ -203,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the database.
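
A usage sketch for the new hash lookup (the `store` instance and the hash value are hypothetical; the exception comes from `invokeai.app.services.model_records.model_records_base` as imported elsewhere in this commit):

```python
from invokeai.app.services.model_records.model_records_base import UnknownModelException

try:
    config = store.get_model_by_hash("blake3:9f2c...")  # hypothetical hash
    print(config.name)
except UnknownModelException:
    print("no model with that hash is installed")
```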

View File

@@ -1,7 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
from PIL.Image import Image
from torch import Tensor
@@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
"""
@@ -299,22 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
def exists(self, key: str) -> bool:
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
"""Checks if a model exists.
Args:
key: The key of the model.
identifier: The key or ModelField representing the model.
Returns:
True if the model exists, False if not.
"""
return self._services.model_manager.store.exists(key)
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
return self._services.model_manager.store.exists(identifier.key)
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model.
Args:
key: The key of the model.
identifier: The key or ModelField representing the model.
submodel_type: The submodel of the model to get.
Returns:
@@ -324,9 +328,13 @@ class ModelsInterface(InvocationContextInterface):
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._data
)
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@@ -343,35 +351,29 @@ class ModelsInterface(InvocationContextInterface):
Returns:
An object representing the loaded model.
"""
return self._services.model_manager.load_model_by_attr(
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
)
def get_config(self, key: str) -> AnyModelConfig:
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
"""Gets a model's config.
Args:
key: The key of the model.
identifier: The key or ModelField representing the model.
Returns:
The model's config.
"""
return self._services.model_manager.store.get_model(key=key)
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""Gets a model's metadata, if it has any.
Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
"""
return self._services.model_manager.store.get_metadata(key=key)
return self._services.model_manager.store.get_model(identifier.key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.

View File

@@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(key=name_or_key)
loaded_model = context.models.load(name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
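
Textual inversions are resolved the same way: the trigger token, minus its angle brackets, is treated as a name-or-key and handed straight to the unified loader (a sketch; the trigger value is hypothetical):

```python
def load_ti_for_trigger(context, trigger: str):
    # "<my-embedding>" -> "my-embedding"; the bare token is passed
    # directly to context.models.load() as a name or key.
    name_or_key = trigger[1:-1]
    return context.models.load(name_or_key)
```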

View File

@@ -35,17 +35,13 @@ from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOut
from invokeai.app.invocations.model import (
ClipField,
CLIPOutput,
LoraInfo,
LoraLoaderOutput,
LoRAModelField,
MainModelField,
ModelInfo,
ModelField,
ModelLoaderOutput,
SDXLLoraLoaderOutput,
UNetField,
UNetOutput,
VaeField,
VAEModelField,
VAEOutput,
)
from invokeai.app.invocations.primitives import (
@@ -73,8 +69,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -118,14 +114,10 @@ __all__ = [
"MetadataItemOutput",
"MetadataOutput",
# invokeai.app.invocations.model
"ModelInfo",
"LoraInfo",
"ModelField",
"UNetField",
"ClipField",
"VaeField",
"MainModelField",
"LoRAModelField",
"VAEModelField",
"UNetOutput",
"VAEOutput",
"CLIPOutput",
@@ -166,7 +158,7 @@ __all__ = [
# invokeai.app.services.config.config_default
"InvokeAIAppConfig",
# invokeai.backend.model_management.model_manager
"LoadedModelInfo",
"LoadedModel",
# invokeai.backend.model_management.models.base
"BaseModelType",
"ModelType",