Add a get_effective_device(...) utility to help determine the effective device of partially loaded models.

Ryan Dick 2024-12-31 18:55:27 +00:00
parent c8b4f2f20d
commit bbc078a364
8 changed files with 39 additions and 11 deletions
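
For context (not part of the diff): the old idiom `next(iter(model.parameters())).device` reports the device of whichever parameter happens to come first, which can be "cpu" even when most of a partially loaded model already sits on the GPU. Below is a minimal sketch of the intended call pattern, assuming a CUDA device is available and using a toy torch.nn.Sequential model purely for illustration:

import torch

from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device

# Toy model used only for illustration; any nn.Module works.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

if torch.cuda.is_available():
    # Simulate a partially loaded model: only the second layer is moved to the GPU.
    model[1].to("cuda")

    # The first parameter still lives on the CPU, so the old idiom reports "cpu"...
    print(next(iter(model.parameters())).device)  # -> cpu
    # ...while get_effective_device() returns the first non-CPU device it finds.
    print(get_effective_device(model))  # -> cuda:0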


@@ -18,6 +18,7 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution,
safe_step,
)
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class DoubleConvBlock(torch.nn.Module):
@@ -109,7 +110,7 @@ class HEDProcessor:
Returns:
The detected edges.
"""
-device = next(iter(self.network.parameters())).device
+device = get_effective_device(self.network)
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
np_image = resize_image_to_resolution(np_image, detect_resolution)
@@ -183,7 +184,7 @@ class HEDEdgeDetector:
The detected edges.
"""
-device = next(iter(self.model.parameters())).device
+device = get_effective_device(self.model)
np_image = pil_to_np(image)


@@ -7,6 +7,7 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager.config import AnyModel
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
def norm_img(np_img):
@@ -31,7 +32,7 @@ class LaMA:
mask = norm_img(mask)
mask = (mask > 0) * 1
-device = next(self._model.buffers()).device
+device = get_effective_device(self._model)
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)


@@ -17,6 +17,7 @@ from invokeai.backend.image_util.util import (
pil_to_np,
resize_image_to_resolution,
)
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class ResidualBlock(nn.Module):
@@ -130,7 +131,7 @@ class LineartProcessor:
Returns:
The detected lineart.
"""
-device = next(iter(self.model.parameters())).device
+device = get_effective_device(self.model)
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
@@ -201,7 +202,7 @@ class LineartEdgeDetector:
Returns:
The detected edges.
"""
-device = next(iter(self.model.parameters())).device
+device = get_effective_device(self.model)
np_image = pil_to_np(image)


@@ -19,6 +19,7 @@ from invokeai.backend.image_util.util import (
pil_to_np,
resize_image_to_resolution,
)
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class UnetGenerator(nn.Module):
@@ -171,7 +172,7 @@ class LineartAnimeProcessor:
Returns:
The detected lineart.
"""
-device = next(iter(self.model.parameters())).device
+device = get_effective_device(self.model)
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
@@ -239,7 +240,7 @@ class LineartAnimeEdgeDetector:
def run(self, image: Image.Image) -> Image.Image:
"""Processes an image and returns the detected edges."""
-device = next(iter(self.model.parameters())).device
+device = get_effective_device(self.model)
np_image = pil_to_np(image)


@@ -14,6 +14,8 @@ import numpy as np
import torch
from torch.nn import functional as F
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
'''
@@ -49,7 +51,7 @@ def pred_lines(image, model,
dist_thr=20.0):
h, w, _ = image.shape
-device = next(iter(model.parameters())).device
+device = get_effective_device(model)
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
@@ -108,7 +110,7 @@ def pred_squares(image,
'''
h, w, _ = image.shape
original_shape = [h, w]
-device = next(iter(model.parameters())).device
+device = get_effective_device(model)
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)


@@ -13,6 +13,7 @@ from PIL import Image
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class NormalMapDetector:
@@ -64,7 +65,7 @@ class NormalMapDetector:
def run(self, image: Image.Image):
"""Processes an image and returns the detected normal map."""
-device = next(iter(self.model.parameters())).device
+device = get_effective_device(self.model)
np_image = pil_to_np(image)
height, width, _channels = np_image.shape


@@ -11,6 +11,7 @@ from PIL import Image
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class PIDINetDetector:
@@ -45,7 +46,7 @@ class PIDINetDetector:
) -> Image.Image:
"""Processes an image and returns the detected edges."""
-device = next(iter(self.model.parameters())).device
+device = get_effective_device(self.model)
np_img = pil_to_np(image)
np_img = normalize_image_channel_count(np_img)


@@ -0,0 +1,20 @@
+import itertools
+
+import torch
+
+
+def get_effective_device(model: torch.nn.Module) -> torch.device:
+    """A utility to infer the 'effective' device of a model.
+
+    This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling:
+    `next(iter(model.parameters())).device`.
+
+    In the worst case, this utility has to check all model parameters, so if you already know the intended model device,
+    then it is better to avoid calling this function.
+    """
+    # If all parameters are on the CPU, return the CPU device. Otherwise, return the first non-CPU device.
+    for p in itertools.chain(model.parameters(), model.buffers()):
+        if p.device.type != "cpu":
+            return p.device
+
+    return torch.device("cpu")
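
A note on the design (again, not part of the commit): the helper chains model.parameters() with model.buffers(), which keeps it compatible with call sites like the LaMA one above that previously read the device from a buffer, and it falls back to the CPU device only when nothing is resident elsewhere. A small sketch of that fallback behavior, using a buffer-only module as an assumed example:

import torch

from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device

# Everything on the CPU -> the helper falls back to torch.device("cpu").
assert get_effective_device(torch.nn.Linear(4, 4)) == torch.device("cpu")

# A module with buffers but no parameters (BatchNorm1d with affine=False) is still handled,
# because buffers() is chained after parameters().
assert get_effective_device(torch.nn.BatchNorm1d(4, affine=False)) == torch.device("cpu")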