mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-08 11:57:36 +08:00
ruff
This commit is contained in:
parent
28864f6d7f
commit
9bd1f4a4f4
@ -53,6 +53,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):
|
||||
|
||||
bulk_download_id: str
|
||||
|
||||
|
||||
class ModelLoadSubscriptionEvent(BaseModel):
|
||||
"""Event data for subscribing to the socket.io model loading room.
|
||||
This is a pydantic model to ensure the data is in the correct format."""
|
||||
|
@ -649,7 +649,9 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], queue_id=self._context.util.get_queue_id(), loader=load_depth_anything
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size],
|
||||
queue_id=self._context.util.get_queue_id(),
|
||||
loader=load_depth_anything,
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
@ -435,7 +435,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
controlnet_data: list[ControlNetData] = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model, context.util.get_queue_id()))
|
||||
control_model = exit_stack.enter_context(
|
||||
context.models.load(control_info.control_model, context.util.get_queue_id())
|
||||
)
|
||||
assert isinstance(control_model, ControlNetModel)
|
||||
|
||||
control_image_field = control_info.image
|
||||
@ -492,7 +494,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||
|
||||
for control_info in control_list:
|
||||
model = exit_stack.enter_context(context.models.load(control_info.control_model, context.util.get_queue_id()))
|
||||
model = exit_stack.enter_context(
|
||||
context.models.load(control_info.control_model, context.util.get_queue_id())
|
||||
)
|
||||
ext_manager.add_extension(
|
||||
ControlNetExt(
|
||||
model=model,
|
||||
@ -545,9 +549,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
|
||||
image_prompts = []
|
||||
for single_ip_adapter in ip_adapters:
|
||||
with context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id()) as ip_adapter_model:
|
||||
with context.models.load(
|
||||
single_ip_adapter.ip_adapter_model, context.util.get_queue_id()
|
||||
) as ip_adapter_model:
|
||||
assert isinstance(ip_adapter_model, IPAdapter)
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model, context.util.get_queue_id())
|
||||
image_encoder_model_info = context.models.load(
|
||||
single_ip_adapter.image_encoder_model, context.util.get_queue_id()
|
||||
)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
@ -581,7 +589,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
|
||||
ip_adapters, image_prompts, strict=True
|
||||
):
|
||||
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id()))
|
||||
ip_adapter_model = exit_stack.enter_context(
|
||||
context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id())
|
||||
)
|
||||
|
||||
mask_field = single_ip_adapter.mask
|
||||
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
|
||||
@ -621,7 +631,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id())
|
||||
t2i_adapter_loaded_model = context.models.load(
|
||||
t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id()
|
||||
)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
|
@ -35,7 +35,9 @@ class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithB
|
||||
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
loaded_model = context.models.load_remote_model(model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model
|
||||
)
|
||||
|
||||
with loaded_model as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
|
@ -32,7 +32,7 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
onnx_det_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
loaded_session_pose = context.models.load_local_model(
|
||||
onnx_pose_path, context.util.get_queue_id(),DWOpenposeDetector2.create_onnx_inference_session
|
||||
onnx_pose_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
|
||||
with loaded_session_det as session_det, loaded_session_pose as session_pose:
|
||||
|
@ -468,7 +468,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# minimize peak memory.
|
||||
|
||||
# First, load the ControlNet models so that we can determine the ControlNet types.
|
||||
controlnet_models = [context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets]
|
||||
controlnet_models = [
|
||||
context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets
|
||||
]
|
||||
|
||||
# Calculate the controlnet conditioning tensors.
|
||||
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
|
||||
@ -590,7 +592,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
pos_images.append(pos_image)
|
||||
neg_images.append(neg_image)
|
||||
|
||||
with context.models.load(ip_adapter_field.image_encoder_model, context.util.get_queue_id()) as image_encoder_model:
|
||||
with context.models.load(
|
||||
ip_adapter_field.image_encoder_model, context.util.get_queue_id()
|
||||
) as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
|
||||
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
|
||||
@ -620,7 +624,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
|
||||
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
|
||||
):
|
||||
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id()))
|
||||
ip_adapter_model = exit_stack.enter_context(
|
||||
context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id())
|
||||
)
|
||||
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
|
||||
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
|
||||
if ip_adapter_field.mask is not None:
|
||||
|
@ -94,7 +94,9 @@ class GroundingDinoInvocation(BaseInvocation):
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
with context.models.load_remote_model(
|
||||
source=GROUNDING_DINO_MODEL_IDS[self.model], queue_id=context.util.get_queue_id(), loader=GroundingDinoInvocation._load_grounding_dino
|
||||
source=GROUNDING_DINO_MODEL_IDS[self.model],
|
||||
queue_id=context.util.get_queue_id(),
|
||||
loader=GroundingDinoInvocation._load_grounding_dino,
|
||||
) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
|
@ -22,7 +22,9 @@ class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, ControlNetHED_Apache2)
|
||||
|
@ -23,7 +23,9 @@ class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
model_url = LineartEdgeDetector.get_model_url(self.coarse)
|
||||
loaded_model = context.models.load_remote_model(model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, Generator)
|
||||
|
@ -20,7 +20,9 @@ class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoar
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
model_url = LineartAnimeEdgeDetector.get_model_url()
|
||||
loaded_model = context.models.load_remote_model(model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, UnetGenerator)
|
||||
|
@ -28,7 +28,9 @@ class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(),context.util.get_queue_id(), MLSDDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
MLSDDetector.get_model_url(), context.util.get_queue_id(), MLSDDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, MobileV2_MLSD_Large)
|
||||
|
@ -20,7 +20,9 @@ class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, NNET)
|
||||
|
@ -22,7 +22,9 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(),context.util.get_queue_id(), PIDINetDetector.load_model)
|
||||
loaded_model = context.models.load_remote_model(
|
||||
PIDINetDetector.get_model_url(), context.util.get_queue_id(), PIDINetDetector.load_model
|
||||
)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, PiDiNet)
|
||||
|
@ -125,7 +125,9 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
|
||||
with (
|
||||
context.models.load_remote_model(
|
||||
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],queue_id=context.util.get_queue_id(), loader=SegmentAnythingInvocation._load_sam_model
|
||||
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
|
||||
queue_id=context.util.get_queue_id(),
|
||||
loader=SegmentAnythingInvocation._load_sam_model,
|
||||
) as sam_pipeline,
|
||||
):
|
||||
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
||||
|
@ -131,15 +131,17 @@ class EventServiceBase:
|
||||
|
||||
# region Model loading
|
||||
|
||||
def emit_model_load_started(self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None) -> None:
|
||||
def emit_model_load_started(
|
||||
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
|
||||
) -> None:
|
||||
"""Emitted when a model load is started."""
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, queue_id, submodel_type))
|
||||
|
||||
def emit_model_load_complete(
|
||||
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
|
||||
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
|
||||
) -> None:
|
||||
"""Emitted when a model load is complete."""
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id,submodel_type))
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id, submodel_type))
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -399,7 +399,9 @@ class ModelLoadStartedEvent(ModelLoadEventBase):
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
|
||||
def build(
|
||||
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> "ModelLoadStartedEvent":
|
||||
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
|
||||
|
||||
|
||||
@ -413,13 +415,16 @@ class ModelLoadCompleteEvent(ModelLoadEventBase):
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
|
||||
def build(
|
||||
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> "ModelLoadCompleteEvent":
|
||||
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
"""Base class for model events"""
|
||||
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallDownloadStartedEvent(ModelEventBase):
|
||||
"""Event model for model_install_download_started"""
|
||||
|
@ -14,7 +14,9 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
|
@ -49,7 +49,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
|
@ -351,7 +351,10 @@ class ModelsInterface(InvocationContextInterface):
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(
|
||||
self, identifier: Union[str, "ModelIdentifierField"], queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
self,
|
||||
identifier: Union[str, "ModelIdentifierField"],
|
||||
queue_id: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> LoadedModel:
|
||||
"""Load a model.
|
||||
|
||||
@ -375,7 +378,12 @@ class ModelsInterface(InvocationContextInterface):
|
||||
return self._services.model_manager.load.load_model(model, queue_id, _submodel_type)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType,queue_id: str, submodel_type: Optional[SubModelType] = None
|
||||
self,
|
||||
name: str,
|
||||
base: BaseModelType,
|
||||
type: ModelType,
|
||||
queue_id: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> LoadedModel:
|
||||
"""Load a model by its attributes.
|
||||
|
||||
@ -472,7 +480,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
def load_local_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
queue_id: str,
|
||||
queue_id: str,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
@ -490,12 +498,14 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path,queue_id=queue_id, loader=loader)
|
||||
return self._services.model_manager.load.load_model_from_path(
|
||||
model_path=model_path, queue_id=queue_id, loader=loader
|
||||
)
|
||||
|
||||
def load_remote_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
queue_id: str,
|
||||
queue_id: str,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
@ -516,7 +526,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, queue_id=queue_id, loader=loader)
|
||||
return self._services.model_manager.load.load_model_from_path(
|
||||
model_path=model_path, queue_id=queue_id, loader=loader
|
||||
)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
@ -544,7 +556,7 @@ class UtilInterface(InvocationContextInterface):
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._data.queue_item.queue_id
|
||||
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current session has been canceled.
|
||||
|
||||
|
@ -66,10 +66,14 @@ def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) ->
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id())
|
||||
loaded_model_1 = mock_context.models.load_remote_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
|
||||
)
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id())
|
||||
loaded_model_2 = mock_context.models.load_remote_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors", mock_context.util.get_queue_id()
|
||||
)
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user