mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
Add get_effective_device(...) utility to aid in determining the effective device of models that are partially loaded.
This commit is contained in:
parent
c8b4f2f20d
commit
bbc078a364
@ -18,6 +18,7 @@ from invokeai.backend.image_util.util import (
|
||||
resize_image_to_resolution,
|
||||
safe_step,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class DoubleConvBlock(torch.nn.Module):
|
||||
@ -109,7 +110,7 @@ class HEDProcessor:
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
device = next(iter(self.network.parameters())).device
|
||||
device = get_effective_device(self.network)
|
||||
np_image = pil_to_np(input_image)
|
||||
np_image = normalize_image_channel_count(np_image)
|
||||
np_image = resize_image_to_resolution(np_image, detect_resolution)
|
||||
@ -183,7 +184,7 @@ class HEDEdgeDetector:
|
||||
The detected edges.
|
||||
"""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
device = get_effective_device(self.model)
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
|
@ -7,6 +7,7 @@ from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@ -31,7 +32,7 @@ class LaMA:
|
||||
mask = norm_img(mask)
|
||||
mask = (mask > 0) * 1
|
||||
|
||||
device = next(self._model.buffers()).device
|
||||
device = get_effective_device(self._model)
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
|
@ -17,6 +17,7 @@ from invokeai.backend.image_util.util import (
|
||||
pil_to_np,
|
||||
resize_image_to_resolution,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
@ -130,7 +131,7 @@ class LineartProcessor:
|
||||
Returns:
|
||||
The detected lineart.
|
||||
"""
|
||||
device = next(iter(self.model.parameters())).device
|
||||
device = get_effective_device(self.model)
|
||||
|
||||
np_image = pil_to_np(input_image)
|
||||
np_image = normalize_image_channel_count(np_image)
|
||||
@ -201,7 +202,7 @@ class LineartEdgeDetector:
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
device = next(iter(self.model.parameters())).device
|
||||
device = get_effective_device(self.model)
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
|
@ -19,6 +19,7 @@ from invokeai.backend.image_util.util import (
|
||||
pil_to_np,
|
||||
resize_image_to_resolution,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class UnetGenerator(nn.Module):
|
||||
@ -171,7 +172,7 @@ class LineartAnimeProcessor:
|
||||
Returns:
|
||||
The detected lineart.
|
||||
"""
|
||||
device = next(iter(self.model.parameters())).device
|
||||
device = get_effective_device(self.model)
|
||||
np_image = pil_to_np(input_image)
|
||||
|
||||
np_image = normalize_image_channel_count(np_image)
|
||||
@ -239,7 +240,7 @@ class LineartAnimeEdgeDetector:
|
||||
|
||||
def run(self, image: Image.Image) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges."""
|
||||
device = next(iter(self.model.parameters())).device
|
||||
device = get_effective_device(self.model)
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
|
@ -14,6 +14,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
|
||||
'''
|
||||
@ -49,7 +51,7 @@ def pred_lines(image, model,
|
||||
dist_thr=20.0):
|
||||
h, w, _ = image.shape
|
||||
|
||||
device = next(iter(model.parameters())).device
|
||||
device = get_effective_device(model)
|
||||
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
|
||||
|
||||
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
|
||||
@ -108,7 +110,7 @@ def pred_squares(image,
|
||||
'''
|
||||
h, w, _ = image.shape
|
||||
original_shape = [h, w]
|
||||
device = next(iter(model.parameters())).device
|
||||
device = get_effective_device(model)
|
||||
|
||||
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
|
||||
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
||||
|
@ -13,6 +13,7 @@ from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class NormalMapDetector:
|
||||
@ -64,7 +65,7 @@ class NormalMapDetector:
|
||||
def run(self, image: Image.Image):
|
||||
"""Processes an image and returns the detected normal map."""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
device = get_effective_device(self.model)
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_image.shape
|
||||
|
@ -11,6 +11,7 @@ from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
|
||||
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class PIDINetDetector:
|
||||
@ -45,7 +46,7 @@ class PIDINetDetector:
|
||||
) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges."""
|
||||
|
||||
device = next(iter(self.model.parameters())).device
|
||||
device = get_effective_device(self.model)
|
||||
|
||||
np_img = pil_to_np(image)
|
||||
np_img = normalize_image_channel_count(np_img)
|
||||
|
20
invokeai/backend/model_manager/load/model_cache/utils.py
Normal file
20
invokeai/backend/model_manager/load/model_cache/utils.py
Normal file
@ -0,0 +1,20 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_effective_device(model: torch.nn.Module) -> torch.device:
|
||||
"""A utility to infer the 'effective' device of a model.
|
||||
|
||||
This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling:
|
||||
`next(iter(model.parameters())).device`.
|
||||
|
||||
In the worst case, this utility has to check all model parameters, so if you already know the intended model device,
|
||||
then it is better to avoid calling this function.
|
||||
"""
|
||||
# If all parameters are on the CPU, return the CPU device. Otherwise, return the first non-CPU device.
|
||||
for p in itertools.chain(model.parameters(), model.buffers()):
|
||||
if p.device.type != "cpu":
|
||||
return p.device
|
||||
|
||||
return torch.device("cpu")
|
Loading…
Reference in New Issue
Block a user