mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
feat(nodes): add LineartEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.
This commit is contained in:
parent
cd2c2a7fde
commit
c5f3297841
34
invokeai/app/invocations/lineart.py
Normal file
34
invokeai/app/invocations/lineart.py
Normal file
@ -0,0 +1,34 @@
|
||||
from builtins import bool
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_edge_detection",
|
||||
title="Lineart Edge Detection",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an edge map using the Lineart model."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
model_url = LineartEdgeDetector.get_model_url(self.coarse)
|
||||
loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, Generator)
|
||||
detector = LineartEdgeDetector(model)
|
||||
edge_map = detector.run(image=image)
|
||||
|
||||
image_dto = context.images.save(image=edge_map)
|
||||
return ImageOutput.build(image_dto)
|
@ -1,6 +1,9 @@
|
||||
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -156,3 +159,63 @@ class LineartProcessor:
|
||||
detected_map = 255 - detected_map
|
||||
|
||||
return np_to_pil(detected_map)
|
||||
|
||||
|
||||
class LineartEdgeDetector:
|
||||
"""Simple wrapper around the fine and coarse lineart models for detecting edges in an image."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename_fine = "sk_model.pth"
|
||||
hf_filename_coarse = "sk_model2.pth"
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls, coarse: bool = False) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
if coarse:
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_coarse)
|
||||
else:
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename_fine)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> Generator:
|
||||
"""Load the model from a file."""
|
||||
model = Generator(3, 1, 3)
|
||||
model.load_state_dict(torch.load(model_path, map_location="cpu"))
|
||||
model.float().eval()
|
||||
return model
|
||||
|
||||
def __init__(self, model: Generator) -> None:
|
||||
self.model = model
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image) -> Image.Image:
|
||||
"""Detects edges in the input image with the selected lineart model.
|
||||
|
||||
Args:
|
||||
input: The input image.
|
||||
coarse: Whether to use the coarse model.
|
||||
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
with torch.no_grad():
|
||||
np_image = torch.from_numpy(np_image).float().to(device)
|
||||
np_image = np_image / 255.0
|
||||
np_image = rearrange(np_image, "h w c -> 1 c h w")
|
||||
line = self.model(np_image)[0][0]
|
||||
|
||||
line = line.cpu().numpy()
|
||||
line = (line * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
detected_map = line
|
||||
|
||||
detected_map = 255 - detected_map
|
||||
|
||||
return np_to_pil(detected_map)
|
||||
|
Loading…
x
Reference in New Issue
Block a user