feat(nodes): add LineartAnimeEdgeDetectionInvocation

Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.
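The shape of that revised loading pattern, excerpted from the new node below (context is the InvocationContext passed to invoke; this is a summary of the diff, not additional API):

    model_url = LineartAnimeEdgeDetector.get_model_url()
    with context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model) as model:
        detector = LineartAnimeEdgeDetector(model)  # model is downloaded and cached on first use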
psychedelicious 2024-09-10 20:57:16 +10:00 committed by Kent Keirsey
parent 1cffcc02a5
commit cd2c2a7fde
2 changed files with 95 additions and 0 deletions


@@ -0,0 +1,31 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector, UnetGenerator


@invocation(
    "lineart_anime_edge_detection",
    title="Lineart Anime Edge Detection",
    tags=["controlnet", "lineart"],
    category="controlnet",
    version="1.0.0",
)
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an edge map using the Lineart Anime model."""

    image: ImageField = InputField(description="The image to process")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, "RGB")
        # Load the model via the model manager, downloading it from the given URL
        # if it is not already cached.
        model_url = LineartAnimeEdgeDetector.get_model_url()
        loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
        with loaded_model as model:
            assert isinstance(model, UnetGenerator)
            detector = LineartAnimeEdgeDetector(model)
            edge_map = detector.run(image=image)

        image_dto = context.images.save(image=edge_map)
        return ImageOutput.build(image_dto)
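Because the node itself performs no resizing, the edge map comes back at the input image's dimensions; the snap to multiples of 256 happens inside the detector (second file below) and is undone before returning. A quick standalone sanity check (hypothetical snippet, not part of the diff; assumes a model already loaded via load_model):

    from PIL import Image

    img = Image.new("RGB", (613, 481))  # arbitrary size, not a multiple of 256
    edges = LineartAnimeEdgeDetector(model).run(image=img)
    assert edges.size == img.size  # the detector resizes its output back to the input dimensions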


@@ -1,9 +1,11 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
import functools
import pathlib
from typing import Optional
import cv2
import huggingface_hub
import numpy as np
import torch
import torch.nn as nn
@@ -201,3 +203,65 @@ class LineartAnimeProcessor:
        detected_map = 255 - detected_map
        return np_to_pil(detected_map)
class LineartAnimeEdgeDetector:
    """Simple wrapper around the Lineart Anime model for detecting edges in an image."""

    hf_repo_id = "lllyasviel/Annotators"
    hf_filename = "netG.pth"

    @classmethod
    def get_model_url(cls) -> str:
        """Get the URL to download the model from the Hugging Face Hub."""
        return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)

    @classmethod
    def load_model(cls, model_path: pathlib.Path) -> UnetGenerator:
        """Load the model from a file."""
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
        ckpt = torch.load(model_path)
        # The checkpoint was saved from a DataParallel-wrapped model; strip the
        # "module." prefix from the state dict keys so they match the bare UnetGenerator.
        for key in list(ckpt.keys()):
            if "module." in key:
                ckpt[key.replace("module.", "")] = ckpt[key]
                del ckpt[key]
        model.load_state_dict(ckpt)
        model.eval()
        return model

    def __init__(self, model: UnetGenerator) -> None:
        self.model = model

    def to(self, device: torch.device):
        self.model.to(device)
        return self

    def run(self, image: Image.Image) -> Image.Image:
        """Processes an image and returns the detected edges."""
        device = next(iter(self.model.parameters())).device
        np_image = pil_to_np(image)

        # The UNet expects dimensions that are multiples of 256; round up here and
        # resize back to the original dimensions afterwards.
        height, width, _channels = np_image.shape
        new_height = 256 * int(np.ceil(float(height) / 256.0))
        new_width = 256 * int(np.ceil(float(width) / 256.0))
        resized_img = cv2.resize(np_image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

        with torch.no_grad():
            # Normalize from [0, 255] to [-1, 1] and convert HWC -> NCHW before feeding the model.
            image_feed = torch.from_numpy(resized_img).float().to(device)
            image_feed = image_feed / 127.5 - 1.0
            image_feed = rearrange(image_feed, "h w c -> 1 c h w")

            # Map the model output back to [0, 255] and restore the input dimensions.
            line = self.model(image_feed)[0, 0] * 127.5 + 127.5
            line = line.cpu().numpy()
            line = cv2.resize(line, (width, height), interpolation=cv2.INTER_CUBIC)
            line = line.clip(0, 255).astype(np.uint8)

        detected_map = line
        # Invert the line map.
        detected_map = 255 - detected_map

        output = np_to_pil(detected_map)
        return output
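For reference, a minimal standalone sketch of using the detector outside the node graph (assumes huggingface_hub and torch are installed; the CUDA device and the file names are illustrative placeholders, CPU works as well):

    import pathlib

    import huggingface_hub
    import torch
    from PIL import Image

    # Download the checkpoint from the Hub (huggingface_hub caches it locally).
    ckpt_path = huggingface_hub.hf_hub_download(
        LineartAnimeEdgeDetector.hf_repo_id, LineartAnimeEdgeDetector.hf_filename
    )
    model = LineartAnimeEdgeDetector.load_model(pathlib.Path(ckpt_path))
    detector = LineartAnimeEdgeDetector(model).to(torch.device("cuda"))

    edge_map = detector.run(image=Image.open("input.png").convert("RGB"))
    edge_map.save("edges.png")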