mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui
synced 2025-01-08 12:07:30 +08:00
Merge pull request #14478 from akx/dtype-inspect
Add utility to inspect a model's dtype/device
This commit is contained in:
commit
be5f1acc8f
@ -4,6 +4,7 @@ from functools import lru_cache
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modules import errors, shared
|
from modules import errors, shared
|
||||||
|
from modules.torch_utils import get_param
|
||||||
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
from modules import mac_specific
|
from modules import mac_specific
|
||||||
@ -131,7 +132,7 @@ patch_module_list = [
|
|||||||
|
|
||||||
|
|
||||||
def manual_cast_forward(self, *args, **kwargs):
|
def manual_cast_forward(self, *args, **kwargs):
|
||||||
org_dtype = next(self.parameters()).dtype
|
org_dtype = get_param(self).dtype
|
||||||
self.to(dtype)
|
self.to(dtype)
|
||||||
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||||
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||||
|
@ -11,6 +11,7 @@ from torchvision import transforms
|
|||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||||
|
from modules.torch_utils import get_param
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
clip_model_name = 'ViT-L/14'
|
clip_model_name = 'ViT-L/14'
|
||||||
@ -131,7 +132,7 @@ class InterrogateModels:
|
|||||||
|
|
||||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||||
|
|
||||||
self.dtype = next(self.clip_model.parameters()).dtype
|
self.dtype = get_param(self.clip_model).dtype
|
||||||
|
|
||||||
def send_clip_to_ram(self):
|
def send_clip_to_ram(self):
|
||||||
if not shared.opts.interrogate_keep_models_in_memory:
|
if not shared.opts.interrogate_keep_models_in_memory:
|
||||||
|
@ -6,6 +6,7 @@ import sgm.models.diffusion
|
|||||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||||
import sgm.modules.diffusionmodules.discretizer
|
import sgm.modules.diffusionmodules.discretizer
|
||||||
from modules import devices, shared, prompt_parser
|
from modules import devices, shared, prompt_parser
|
||||||
|
from modules.torch_utils import get_param
|
||||||
|
|
||||||
|
|
||||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||||
@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
|
|||||||
def extend_sdxl(model):
|
def extend_sdxl(model):
|
||||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||||
|
|
||||||
dtype = next(model.model.diffusion_model.parameters()).dtype
|
dtype = get_param(model.model.diffusion_model).dtype
|
||||||
model.model.diffusion_model.dtype = dtype
|
model.model.diffusion_model.dtype = dtype
|
||||||
model.model.conditioning_key = 'crossattn'
|
model.model.conditioning_key = 'crossattn'
|
||||||
model.cond_stage_key = 'txt'
|
model.cond_stage_key = 'txt'
|
||||||
|
17
modules/torch_utils.py
Normal file
17
modules/torch_utils.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch.nn
|
||||||
|
|
||||||
|
|
||||||
|
def get_param(model) -> torch.nn.Parameter:
    """
    Return the first parameter yielded by a model or module.

    When `model` has a `model` attribute that itself exposes `parameters`,
    that inner object is inspected instead of the outer one.

    Raises:
        ValueError: if the inspected object yields no parameters at all.
    """
    wrapped = getattr(model, "model", None)
    if hasattr(wrapped, "parameters"):
        # Unpeel a model descriptor to get at the actual Torch module.
        model = wrapped

    # A private sentinel tells "no parameters" apart from any yielded value.
    nothing = object()
    first = next(iter(model.parameters()), nothing)
    if first is nothing:
        raise ValueError(f"No parameters found in model {model!r}")

    return first
|
@ -7,6 +7,7 @@ import tqdm
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from modules import images, shared
|
from modules import images, shared
|
||||||
|
from modules.torch_utils import get_param
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image):
|
|||||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||||
img = torch.from_numpy(img).float()
|
img = torch.from_numpy(img).float()
|
||||||
|
|
||||||
model_weight = next(iter(model.model.parameters()))
|
param = get_param(model)
|
||||||
img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype)
|
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(img)
|
output = model(img)
|
||||||
|
@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
|
|||||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from modules.torch_utils import get_param
|
||||||
|
|
||||||
|
|
||||||
class BertSeriesConfig(BertConfig):
|
class BertSeriesConfig(BertConfig):
|
||||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||||
|
|
||||||
@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def encode(self,c):
|
def encode(self,c):
|
||||||
device = next(self.parameters()).device
|
device = get_param(self).device
|
||||||
text = self.tokenizer(c,
|
text = self.tokenizer(c,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=77,
|
max_length=77,
|
||||||
|
@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
|
|||||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from modules.torch_utils import get_param
|
||||||
|
|
||||||
|
|
||||||
class BertSeriesConfig(BertConfig):
|
class BertSeriesConfig(BertConfig):
|
||||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||||
|
|
||||||
@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def encode(self,c):
|
def encode(self,c):
|
||||||
device = next(self.parameters()).device
|
device = get_param(self).device
|
||||||
text = self.tokenizer(c,
|
text = self.tokenizer(c,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=77,
|
max_length=77,
|
||||||
|
19
test/test_torch_utils.py
Normal file
19
test/test_torch_utils.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules.torch_utils import get_param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    """
    Check that get_param returns a parameter carrying the module's dtype and
    device, both for a bare module and for one held in a `.model` attribute.
    """
    cpu = torch.device("cpu")
    linear = torch.nn.Linear(1, 1).to(dtype=torch.float16, device=cpu)

    # more or less how spandrel wraps a thing
    subject = types.SimpleNamespace(model=linear) if wrapped else linear

    p = get_param(subject)
    assert p.dtype == torch.float16
    assert p.device == cpu
|
Loading…
Reference in New Issue
Block a user