mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
WIP - add invocations to support FLUX LORAs.
This commit is contained in:
parent
db61ec4322
commit
01a15b4d80
@ -1,4 +1,4 @@
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
@ -29,6 +29,8 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@ -187,7 +189,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
with transformer_info as transformer:
|
||||
with (
|
||||
transformer_info.model_on_device() as (cached_weights, transformer),
|
||||
# Apply the LoRA after transformer has been moved to its target device for faster patching.
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix="transformer",
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
@ -247,6 +258,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# `latents`.
|
||||
return mask.expand_as(latents)
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
|
||||
|
53
invokeai/app/invocations/flux_lora_loader.py
Normal file
53
invokeai/app/invocations/flux_lora_loader.py
Normal file
@ -0,0 +1,53 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""FLUX LoRA Loader Output"""
|
||||
|
||||
transformer: TransformerField = OutputField(
|
||||
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_lora_loader",
|
||||
title="FLUX LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="FLUX Transformer",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
|
||||
|
||||
transformer = self.transformer.model_copy(deep=True)
|
||||
transformer.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return FluxLoRALoaderOutput(transformer=transformer)
|
@ -69,6 +69,7 @@ class CLIPField(BaseModel):
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
@ -202,7 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer),
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
|
102
invokeai/backend/lora/lora_patcher.py
Normal file
102
invokeai/backend/lora/lora_patcher.py
Normal file
@ -0,0 +1,102 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
class LoRAPatcher:
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_lora_patches(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
patches: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Apply one or more LoRA patches to a model.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
|
||||
that the LoRA patches do not need to be loaded into memory all at once.
|
||||
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
|
||||
"""
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
cls._apply_lora_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
yield
|
||||
finally:
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def _apply_lora_patch(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
"""
|
||||
Apply one a LoRA to a model.
|
||||
:param model: The model to patch.
|
||||
:param patch: LoRA model to patch in.
|
||||
:param patch_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
||||
"""
|
||||
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module = model.get_submodule(layer_key)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = layer_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= patch_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
@ -10,6 +10,7 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.lora.conversions.flux_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
@ -26,7 +27,6 @@ from invokeai.backend.model_manager.config import (
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.peft.conversions.flux_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user