From bbc078a364eaf028ec5b16e97c55ef6de61423bb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 31 Dec 2024 18:55:27 +0000 Subject: [PATCH] Add get_effective_device(...) utility to aid in determining the effective device of models that are partially loaded. --- invokeai/backend/image_util/hed.py | 5 +++-- .../backend/image_util/infill_methods/lama.py | 3 ++- invokeai/backend/image_util/lineart.py | 5 +++-- invokeai/backend/image_util/lineart_anime.py | 5 +++-- invokeai/backend/image_util/mlsd/utils.py | 6 ++++-- .../backend/image_util/normal_bae/__init__.py | 3 ++- invokeai/backend/image_util/pidi/__init__.py | 3 ++- .../model_manager/load/model_cache/utils.py | 20 +++++++++++++++++++ 8 files changed, 39 insertions(+), 11 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_cache/utils.py diff --git a/invokeai/backend/image_util/hed.py b/invokeai/backend/image_util/hed.py index ec12c26b2e..a2d3449f65 100644 --- a/invokeai/backend/image_util/hed.py +++ b/invokeai/backend/image_util/hed.py @@ -18,6 +18,7 @@ from invokeai.backend.image_util.util import ( resize_image_to_resolution, safe_step, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class DoubleConvBlock(torch.nn.Module): @@ -109,7 +110,7 @@ class HEDProcessor: Returns: The detected edges. """ - device = next(iter(self.network.parameters())).device + device = get_effective_device(self.network) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) np_image = resize_image_to_resolution(np_image, detect_resolution) @@ -183,7 +184,7 @@ class HEDEdgeDetector: The detected edges. 
""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index cd5838d1f2..faf25e44a4 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -7,6 +7,7 @@ from PIL import Image import invokeai.backend.util.logging as logger from invokeai.backend.model_manager.config import AnyModel +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device def norm_img(np_img): @@ -31,7 +32,7 @@ class LaMA: mask = norm_img(mask) mask = (mask > 0) * 1 - device = next(self._model.buffers()).device + device = get_effective_device(self._model) image = torch.from_numpy(image).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device) diff --git a/invokeai/backend/image_util/lineart.py b/invokeai/backend/image_util/lineart.py index 8fcca24b0e..bfef6f6da0 100644 --- a/invokeai/backend/image_util/lineart.py +++ b/invokeai/backend/image_util/lineart.py @@ -17,6 +17,7 @@ from invokeai.backend.image_util.util import ( pil_to_np, resize_image_to_resolution, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class ResidualBlock(nn.Module): @@ -130,7 +131,7 @@ class LineartProcessor: Returns: The detected lineart. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) @@ -201,7 +202,7 @@ class LineartEdgeDetector: Returns: The detected edges. 
""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/lineart_anime.py b/invokeai/backend/image_util/lineart_anime.py index 09dcb6655e..fa406cf1d4 100644 --- a/invokeai/backend/image_util/lineart_anime.py +++ b/invokeai/backend/image_util/lineart_anime.py @@ -19,6 +19,7 @@ from invokeai.backend.image_util.util import ( pil_to_np, resize_image_to_resolution, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class UnetGenerator(nn.Module): @@ -171,7 +172,7 @@ class LineartAnimeProcessor: Returns: The detected lineart. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) @@ -239,7 +240,7 @@ class LineartAnimeEdgeDetector: def run(self, image: Image.Image) -> Image.Image: """Processes an image and returns the detected edges.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/mlsd/utils.py b/invokeai/backend/image_util/mlsd/utils.py index dbe9a98d09..dbadce01a4 100644 --- a/invokeai/backend/image_util/mlsd/utils.py +++ b/invokeai/backend/image_util/mlsd/utils.py @@ -14,6 +14,8 @@ import numpy as np import torch from torch.nn import functional as F +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device + def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): ''' @@ -49,7 +51,7 @@ def pred_lines(image, model, dist_thr=20.0): h, w, _ = image.shape - device = next(iter(model.parameters())).device + device = get_effective_device(model) h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), @@ -108,7 +110,7 @@ 
def pred_squares(image, ''' h, w, _ = image.shape original_shape = [h, w] - device = next(iter(model.parameters())).device + device = get_effective_device(model) resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), np.ones([input_shape[0], input_shape[1], 1])], axis=-1) diff --git a/invokeai/backend/image_util/normal_bae/__init__.py b/invokeai/backend/image_util/normal_bae/__init__.py index d0b1339113..5ad221ecd4 100644 --- a/invokeai/backend/image_util/normal_bae/__init__.py +++ b/invokeai/backend/image_util/normal_bae/__init__.py @@ -13,6 +13,7 @@ from PIL import Image from invokeai.backend.image_util.normal_bae.nets.NNET import NNET from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class NormalMapDetector: @@ -64,7 +65,7 @@ class NormalMapDetector: def run(self, image: Image.Image): """Processes an image and returns the detected normal map.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) height, width, _channels = np_image.shape diff --git a/invokeai/backend/image_util/pidi/__init__.py b/invokeai/backend/image_util/pidi/__init__.py index 8673b21914..63df7b6058 100644 --- a/invokeai/backend/image_util/pidi/__init__.py +++ b/invokeai/backend/image_util/pidi/__init__.py @@ -11,6 +11,7 @@ from PIL import Image from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class PIDINetDetector: @@ -45,7 +46,7 @@ class PIDINetDetector: ) -> Image.Image: """Processes an image and returns the detected edges.""" - device = next(iter(self.model.parameters())).device + device = 
get_effective_device(self.model) np_img = pil_to_np(image) np_img = normalize_image_channel_count(np_img) diff --git a/invokeai/backend/model_manager/load/model_cache/utils.py b/invokeai/backend/model_manager/load/model_cache/utils.py new file mode 100644 index 0000000000..2b581990c6 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/utils.py @@ -0,0 +1,20 @@ +import itertools + +import torch + + +def get_effective_device(model: torch.nn.Module) -> torch.device: + """A utility to infer the 'effective' device of a model. + + This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling: + `next(iter(model.parameters())).device`. + + In the worst case, this utility has to check all model parameters and buffers, so if you already know the intended model device, + then it is better to avoid calling this function. + """ + # If all parameters and buffers are on the CPU, return the CPU device. Otherwise, return the first non-CPU device. + for p in itertools.chain(model.parameters(), model.buffers()): + if p.device.type != "cpu": + return p.device + + return torch.device("cpu")