feat(nodes): update pidinet node

Human-readable field names.
This commit is contained in:
psychedelicious 2024-09-11 17:43:55 +10:00 committed by Kent Keirsey
parent a4250e3ff2
commit ee4c0efbf7
2 changed files with 4 additions and 4 deletions

View File

@ -17,7 +17,7 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an edge map using PiDiNet."""
image: ImageField = InputField(description="The image to process")
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
quantize_edges: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -27,7 +27,7 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
with loaded_model as model:
assert isinstance(model, PiDiNet)
detector = PIDINetDetector(model)
edge_map = detector.run(image=image, safe=self.safe, scribble=self.scribble)
edge_map = detector.run(image=image, quantize_edges=self.quantize_edges, scribble=self.scribble)
image_dto = context.images.save(image=edge_map)
return ImageOutput.build(image_dto)

View File

@ -41,7 +41,7 @@ class PIDINetDetector:
return self
def run(
self, image: Image.Image, safe: bool = False, scribble: bool = False, apply_filter: bool = False
self, image: Image.Image, quantize_edges: bool = False, scribble: bool = False, apply_filter: bool = False
) -> Image.Image:
"""Processes an image and returns the detected edges."""
@ -62,7 +62,7 @@ class PIDINetDetector:
edge = edge.cpu().numpy()
if apply_filter:
edge = edge > 0.5
if safe:
if quantize_edges:
edge = safe_step(edge)
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)