Mary Hipp 2024-11-06 16:30:37 -05:00
parent 28864f6d7f
commit 9bd1f4a4f4
20 changed files with 103 additions and 37 deletions

View File

@@ -53,6 +53,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
class ModelLoadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io model loading room.
This is a pydantic model to ensure the data is in the correct format."""
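Since the rest of this commit threads a queue_id through every model load call, the new subscription event presumably carries the queue room the client wants to join. A minimal sketch of that shape; the queue_id field is an assumption, not visible in this hunk:

from pydantic import BaseModel


class ModelLoadSubscriptionEventSketch(BaseModel):
    """Hedged sketch of the subscription payload; the field is assumed."""

    queue_id: str  # assumed: which queue's model load events to receive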

View File

@@ -649,7 +649,9 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], queue_id=self._context.util.get_queue_id(), loader=load_depth_anything
source=DEPTH_ANYTHING_MODELS[self.model_size],
queue_id=self._context.util.get_queue_id(),
loader=load_depth_anything,
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
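The loader argument above is a callback that turns the cached checkpoint path into a model instance; later hunks in this commit type it as Callable[[Path], AnyModel]. A hedged sketch of writing such a callback, reusing the DepthAnythingPipeline.load_model classmethod that other hunks pass directly; the body is illustrative:

from pathlib import Path


def load_depth_anything(checkpoint_path: Path):
    # Contract assumed from this commit: receive the local cache path,
    # return the instantiated pipeline object.
    return DepthAnythingPipeline.load_model(checkpoint_path)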

View File

@@ -435,7 +435,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model, context.util.get_queue_id()))
control_model = exit_stack.enter_context(
context.models.load(control_info.control_model, context.util.get_queue_id())
)
assert isinstance(control_model, ControlNetModel)
control_image_field = control_info.image
@@ -492,7 +494,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
for control_info in control_list:
model = exit_stack.enter_context(context.models.load(control_info.control_model, context.util.get_queue_id()))
model = exit_stack.enter_context(
context.models.load(control_info.control_model, context.util.get_queue_id())
)
ext_manager.add_extension(
ControlNetExt(
model=model,
@@ -545,9 +549,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id()) as ip_adapter_model:
with context.models.load(
single_ip_adapter.ip_adapter_model, context.util.get_queue_id()
) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model, context.util.get_queue_id())
image_encoder_model_info = context.models.load(
single_ip_adapter.image_encoder_model, context.util.get_queue_id()
)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -581,7 +589,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id()))
ip_adapter_model = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id())
)
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
@@ -621,7 +631,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id())
t2i_adapter_loaded_model = context.models.load(
t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id()
)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
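A note on the exit_stack pattern used throughout this file's hunks: enter_context defers each model's release until the ExitStack unwinds, so every model stays resident for the whole denoise call rather than for a single with block. A generic sketch using the standard library ExitStack and the load signature from this diff:

from contextlib import ExitStack

with ExitStack() as exit_stack:
    models = []
    for control_info in control_list:
        # __exit__ for each loaded model runs only when the ExitStack closes,
        # keeping all models checked out across this loop and the denoise loop
        models.append(
            exit_stack.enter_context(
                context.models.load(control_info.control_model, context.util.get_queue_id())
            )
        )
    # ... run denoising while every model remains available ...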

View File

@@ -35,7 +35,9 @@ class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model
)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)

View File

@@ -32,7 +32,7 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
onnx_det_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
)
loaded_session_pose = context.models.load_local_model(
onnx_pose_path, context.util.get_queue_id(),DWOpenposeDetector2.create_onnx_inference_session
onnx_pose_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
)
with loaded_session_det as session_det, loaded_session_pose as session_pose:

View File

@@ -468,7 +468,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets]
controlnet_models = [
context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets
]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
@@ -590,7 +592,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_images.append(pos_image)
neg_images.append(neg_image)
with context.models.load(ip_adapter_field.image_encoder_model, context.util.get_queue_id()) as image_encoder_model:
with context.models.load(
ip_adapter_field.image_encoder_model, context.util.get_queue_id()
) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
@@ -620,7 +624,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id()))
ip_adapter_model = exit_stack.enter_context(
context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id())
)
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
if ip_adapter_field.mask is not None:
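The comments in this file's first hunk describe a deliberate ordering to minimize peak memory: acquire the ControlNet load handles first (to learn each model's type), compute the conditioning tensors (which may run the VAE) before the ControlNet weights become resident, and only then enter the model contexts. A sketch of that ordering; classify_controlnet and make_cond_tensor are hypothetical stand-ins:

from contextlib import ExitStack

controlnet_models = [
    context.models.load(controlnet.control_model, context.util.get_queue_id())
    for controlnet in controlnets
]
cn_types = [classify_controlnet(m) for m in controlnet_models]  # hypothetical: inspect each type
cond = [make_cond_tensor(cn) for cn in controlnets]             # hypothetical: may run the VAE
with ExitStack() as stack:
    resident = [stack.enter_context(m) for m in controlnet_models]  # weights resident only now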

View File

@@ -94,7 +94,9 @@ class GroundingDinoInvocation(BaseInvocation):
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model], queue_id=context.util.get_queue_id(), loader=GroundingDinoInvocation._load_grounding_dino
source=GROUNDING_DINO_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=GroundingDinoInvocation._load_grounding_dino,
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
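Grounding DINO expects every candidate label to end with a period, which the normalization at the top of this hunk guarantees; a small illustration of that line:

labels = ["cat", "dog."]
labels = [label if label.endswith(".") else label + "." for label in labels]
assert labels == ["cat.", "dog."]  # each label now carries the trailing period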

View File

@@ -22,7 +22,9 @@ class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, ControlNetHED_Apache2)

View File

@@ -23,7 +23,9 @@ class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartEdgeDetector.get_model_url(self.coarse)
loaded_model = context.models.load_remote_model(model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, Generator)

View File

@@ -20,7 +20,9 @@ class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartAnimeEdgeDetector.get_model_url()
loaded_model = context.models.load_remote_model(model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, UnetGenerator)

View File

@@ -28,7 +28,9 @@ class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(),context.util.get_queue_id(), MLSDDetector.load_model)
loaded_model = context.models.load_remote_model(
MLSDDetector.get_model_url(), context.util.get_queue_id(), MLSDDetector.load_model
)
with loaded_model as model:
assert isinstance(model, MobileV2_MLSD_Large)

View File

@@ -20,7 +20,9 @@ class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model)
loaded_model = context.models.load_remote_model(
NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model
)
with loaded_model as model:
assert isinstance(model, NNET)

View File

@@ -22,7 +22,9 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(),context.util.get_queue_id(), PIDINetDetector.load_model)
loaded_model = context.models.load_remote_model(
PIDINetDetector.get_model_url(), context.util.get_queue_id(), PIDINetDetector.load_model
)
with loaded_model as model:
assert isinstance(model, PiDiNet)

View File

@@ -125,7 +125,9 @@ class SegmentAnythingInvocation(BaseInvocation):
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],queue_id=context.util.get_queue_id(), loader=SegmentAnythingInvocation._load_sam_model
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=SegmentAnythingInvocation._load_sam_model,
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)

View File

@@ -131,15 +131,17 @@ class EventServiceBase:
# region Model loading
def emit_model_load_started(self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None) -> None:
def emit_model_load_started(
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, queue_id, submodel_type))
def emit_model_load_complete(
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id,submodel_type))
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id, submodel_type))
# endregion
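On the receiving end, a client presumably joins a per-queue room so these dispatches reach only subscribers of that queue. A hedged sketch with python-socketio; the event name and room naming scheme are assumptions:

import socketio

sio = socketio.AsyncServer()


@sio.on("subscribe_model_load")  # hypothetical event name
async def subscribe_model_load(sid, data):
    event = ModelLoadSubscriptionEvent(**data)      # validates the payload shape
    sio.enter_room(sid, f"queue_{event.queue_id}")  # hypothetical room name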

View File

@@ -399,7 +399,9 @@ class ModelLoadStartedEvent(ModelLoadEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadStartedEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
@@ -413,13 +415,16 @@ class ModelLoadCompleteEvent(ModelLoadEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadCompleteEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
class ModelEventBase(EventBase):
"""Base class for model events"""
@payload_schema.register
class ModelInstallDownloadStartedEvent(ModelEventBase):
"""Event model for model_install_download_started"""

View File

@@ -14,7 +14,9 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.

View File

@@ -49,7 +49,9 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
def load_model(self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
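The LoadedModel returned here is consumed as a context manager in the invocation hunks above; entering it is what makes the model usable. A minimal consumption sketch, with the call site hypothetical:

loaded = model_load_service.load_model(model_config, queue_id)  # hypothetical service handle
with loaded as model:
    # within the block the model is checked out of the cache and ready to run
    output = model(example_input)  # illustrative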

View File

@@ -351,7 +351,10 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.store.exists(identifier.key)
def load(
self, identifier: Union[str, "ModelIdentifierField"], queue_id: str, submodel_type: Optional[SubModelType] = None
self,
identifier: Union[str, "ModelIdentifierField"],
queue_id: str,
submodel_type: Optional[SubModelType] = None,
) -> LoadedModel:
"""Load a model.
@@ -375,7 +378,12 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.load.load_model(model, queue_id, _submodel_type)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType,queue_id: str, submodel_type: Optional[SubModelType] = None
self,
name: str,
base: BaseModelType,
type: ModelType,
queue_id: str,
submodel_type: Optional[SubModelType] = None,
) -> LoadedModel:
"""Load a model by its attributes.
@@ -472,7 +480,7 @@ class ModelsInterface(InvocationContextInterface):
def load_local_model(
self,
model_path: Path,
queue_id: str,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -490,12 +498,14 @@ class ModelsInterface(InvocationContextInterface):
Returns:
A LoadedModelWithoutConfig object.
"""
return self._services.model_manager.load.load_model_from_path(model_path=model_path,queue_id=queue_id, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
def load_remote_model(
self,
source: str | AnyHttpUrl,
queue_id: str,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -516,7 +526,9 @@ class ModelsInterface(InvocationContextInterface):
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, queue_id=queue_id, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
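Taken together, every ModelsInterface entry point now takes the queue id explicitly. A hedged usage sketch from inside an invocation; the model name, base/type attributes, path, and URL are placeholders:

from pathlib import Path

queue_id = context.util.get_queue_id()
by_id = context.models.load(self.model, queue_id)  # self.model: a ModelIdentifierField
by_attrs = context.models.load_by_attrs(
    "my-vae", BaseModelType.StableDiffusion1, ModelType.VAE, queue_id  # illustrative attrs
)
local = context.models.load_local_model(Path("weights/model.onnx"), queue_id)  # placeholder path
remote = context.models.load_remote_model("https://example.com/m.safetensors", queue_id)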
class ConfigInterface(InvocationContextInterface):
@@ -544,7 +556,7 @@ class UtilInterface(InvocationContextInterface):
True if the current session has been canceled, False if not.
"""
return self._data.queue_item.queue_id
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.

View File

@@ -66,10 +66,14 @@ def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
def test_download_and_load(mock_context: InvocationContext) -> None:
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id())
loaded_model_1 = mock_context.models.load_remote_model(
"https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
)
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id())
loaded_model_2 = mock_context.models.load_remote_model(
"https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
)
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model # should be cached copy