Add Sd3ModelLoaderInvocation.

This commit is contained in:
Ryan Dick 2024-10-24 00:17:19 +00:00 committed by Brandon
parent 9d3f5427b4
commit 009cdb714c
3 changed files with 67 additions and 0 deletions

View File

@ -133,6 +133,7 @@ class FieldDescriptions:
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
mmditx = "MMDiTX"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@ -140,6 +141,7 @@ class FieldDescriptions:
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"

View File

@ -0,0 +1,63 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
"""SD3 base model loader output."""
mmditx: TransformerField = OutputField(description=FieldDescriptions.mmditx, title="MMDiTX")
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"sd3_model_loader",
title="SD3 Main Model",
tags=["model", "sd3"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""
# TODO(ryand): Create a UIType.Sd3MainModelField to use here.
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.MainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
model_key = self.model.key
if not context.models.exists(model_key):
raise ValueError(f"Unknown model: {model_key}")
mmditx = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
clip_encoder_g = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
tokenizer_t5 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
t5_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
return Sd3ModelLoaderOutput(
mmditx=TransformerField(transformer=mmditx, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

View File

@ -84,8 +84,10 @@ class SubModelType(str, Enum):
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"