mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
feat(nodes): add HEDEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.
This commit is contained in:
parent
ac9950bdbb
commit
1cffcc02a5
33
invokeai/app/invocations/hed.py
Normal file
33
invokeai/app/invocations/hed.py
Normal file
@ -0,0 +1,33 @@
|
||||
from builtins import bool
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetector
|
||||
|
||||
|
||||
@invocation(
|
||||
"hed_edge_detection",
|
||||
title="HED Edge Detection",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Geneartes an edge map using the HED (softedge) model."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
|
||||
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, ControlNetHED_Apache2)
|
||||
hed_processor = HEDEdgeDetector(model)
|
||||
edge_map = hed_processor.run(image=image, scribble=self.scribble)
|
||||
|
||||
image_dto = context.images.save(image=edge_map)
|
||||
return ImageOutput.build(image_dto)
|
@ -1,6 +1,9 @@
|
||||
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
|
||||
# Adapted from https://github.com/huggingface/controlnet_aux
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@ -140,3 +143,74 @@ class HEDProcessor:
|
||||
detected_map[detected_map < 255] = 0
|
||||
|
||||
return np_to_pil(detected_map)
|
||||
|
||||
|
||||
class HEDEdgeDetector:
|
||||
"""Simple wrapper around the HED model for detecting edges in an image."""
|
||||
|
||||
hf_repo_id = "lllyasviel/Annotators"
|
||||
hf_filename = "ControlNetHED.pth"
|
||||
|
||||
def __init__(self, model: ControlNetHED_Apache2):
|
||||
self.model = model
|
||||
|
||||
@classmethod
|
||||
def get_model_url(cls) -> str:
|
||||
"""Get the URL to download the model from the Hugging Face Hub."""
|
||||
return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path: pathlib.Path) -> ControlNetHED_Apache2:
|
||||
"""Load the model from a file."""
|
||||
model = ControlNetHED_Apache2()
|
||||
model.load_state_dict(torch.load(model_path, map_location="cpu"))
|
||||
model.float().eval()
|
||||
return model
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges.
|
||||
|
||||
Args:
|
||||
image: The input image.
|
||||
safe: Whether to apply safe step to the detected edges.
|
||||
scribble: Whether to apply non-maximum suppression and Gaussian blur to the detected edges.
|
||||
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_image.shape
|
||||
|
||||
with torch.no_grad():
|
||||
image_hed = torch.from_numpy(np_image.copy()).float().to(device)
|
||||
image_hed = rearrange(image_hed, "h w c -> 1 c h w")
|
||||
edges = self.model(image_hed)
|
||||
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
|
||||
edges = [cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) for e in edges]
|
||||
edges = np.stack(edges, axis=2)
|
||||
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
|
||||
if safe:
|
||||
edge = safe_step(edge)
|
||||
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
detected_map = edge
|
||||
|
||||
detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if scribble:
|
||||
detected_map = nms(detected_map, 127, 3.0)
|
||||
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
|
||||
detected_map[detected_map > 4] = 255
|
||||
detected_map[detected_map < 255] = 0
|
||||
|
||||
output = np_to_pil(detected_map)
|
||||
|
||||
return output
|
||||
|
Loading…
x
Reference in New Issue
Block a user