mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2025-04-04 22:43:40 +08:00)
feat(nodes): add LineartAnimeEdgeDetectionInvocation
Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager.
parent 1cffcc02a5
commit cd2c2a7fde
invokeai/app/invocations/lineart_anime.py (new file, +31 lines)
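Before the diff itself, the revised loading flow at a glance. This is a minimal sketch assembled from the calls in the new node below; context is the InvocationContext passed to invoke(), and checkpoint download/caching is assumed to be handled by the model manager:

    # Sketch of the model-manager loading pattern this commit introduces.
    model_url = LineartAnimeEdgeDetector.get_model_url()  # HF Hub URL for the checkpoint
    loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
    with loaded_model as model:  # the model is ready to use inside this block
        edge_map = LineartAnimeEdgeDetector(model).run(image=image)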
@@ -0,0 +1,31 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector, UnetGenerator


@invocation(
    "lineart_anime_edge_detection",
    title="Lineart Anime Edge Detection",
    tags=["controlnet", "lineart"],
    category="controlnet",
    version="1.0.0",
)
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an edge map using the Lineart Anime model."""

    image: ImageField = InputField(description="The image to process")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, "RGB")

        model_url = LineartAnimeEdgeDetector.get_model_url()
        loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)

        with loaded_model as model:
            assert isinstance(model, UnetGenerator)
            detector = LineartAnimeEdgeDetector(model)
            edge_map = detector.run(image=image)

        image_dto = context.images.save(image=edge_map)
        return ImageOutput.build(image_dto)
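For a quick sanity check, the node can also be constructed directly; a hypothetical snippet (the id and image_name values are illustrative, and in practice the graph executor builds and runs nodes rather than a manual invoke() call):

    node = LineartAnimeEdgeDetectionInvocation(
        id="lineart_anime_1",  # hypothetical node id
        image=ImageField(image_name="my_image.png"),  # hypothetical image name
    )
    output = node.invoke(context)  # context: an InvocationContext supplied by the runtime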
invokeai/backend/image_util/lineart_anime.py
@@ -1,9 +1,11 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""

import functools
import pathlib
from typing import Optional

import cv2
import huggingface_hub
import numpy as np
import torch
import torch.nn as nn
@@ -201,3 +203,65 @@ class LineartAnimeProcessor:
        detected_map = 255 - detected_map

        return np_to_pil(detected_map)


class LineartAnimeEdgeDetector:
    """Simple wrapper around the Lineart Anime model for detecting edges in an image."""

    hf_repo_id = "lllyasviel/Annotators"
    hf_filename = "netG.pth"

    @classmethod
    def get_model_url(cls) -> str:
        """Get the URL to download the model from the Hugging Face Hub."""
        return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)

    @classmethod
    def load_model(cls, model_path: pathlib.Path) -> UnetGenerator:
        """Load the model from a file."""
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
        ckpt = torch.load(model_path)
        # Strip the "module." prefix that nn.DataParallel leaves on state-dict keys.
        for key in list(ckpt.keys()):
            if "module." in key:
                ckpt[key.replace("module.", "")] = ckpt[key]
                del ckpt[key]
        model.load_state_dict(ckpt)
        model.eval()
        return model
    def __init__(self, model: UnetGenerator) -> None:
        self.model = model

    def to(self, device: torch.device):
        self.model.to(device)
        return self

    def run(self, image: Image.Image) -> Image.Image:
        """Processes an image and returns the detected edges."""
        device = next(iter(self.model.parameters())).device

        np_image = pil_to_np(image)

        # The UNet expects spatial dims that are multiples of 256; round up and resize.
        height, width, _channels = np_image.shape
        new_height = 256 * int(np.ceil(float(height) / 256.0))
        new_width = 256 * int(np.ceil(float(width) / 256.0))

        resized_img = cv2.resize(np_image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

        with torch.no_grad():
            # Normalize [0, 255] -> [-1, 1] and reshape HWC -> NCHW.
            image_feed = torch.from_numpy(resized_img).float().to(device)
            image_feed = image_feed / 127.5 - 1.0
            image_feed = rearrange(image_feed, "h w c -> 1 c h w")

            # Take the single-channel output and map it back to [0, 255].
            line = self.model(image_feed)[0, 0] * 127.5 + 127.5
            line = line.cpu().numpy()

        # Restore the original resolution, then invert the map.
        line = cv2.resize(line, (width, height), interpolation=cv2.INTER_CUBIC)
        line = line.clip(0, 255).astype(np.uint8)
        detected_map = 255 - line

        return np_to_pil(detected_map)
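As a usage note, the wrapper can be exercised outside the invocation graph as well. A minimal sketch, assuming the checkpoint is fetched with huggingface_hub.hf_hub_download and that input.png exists locally (the file names here are illustrative, not part of the commit):

    import pathlib

    import huggingface_hub
    from PIL import Image

    from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector

    # Fetch the same checkpoint the class points at (lllyasviel/Annotators, netG.pth).
    ckpt_path = huggingface_hub.hf_hub_download(LineartAnimeEdgeDetector.hf_repo_id, LineartAnimeEdgeDetector.hf_filename)
    model = LineartAnimeEdgeDetector.load_model(pathlib.Path(ckpt_path))

    detector = LineartAnimeEdgeDetector(model)
    edge_map = detector.run(image=Image.open("input.png").convert("RGB"))  # illustrative input file
    edge_map.save("edges.png")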