mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
from typing import Dict, Optional
|
|
|
|
import torch
|
|
|
|
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
|
|
|
|
|
class IA3Layer(LoRALayerBase):
|
|
"""IA3 Layer
|
|
|
|
Example model for testing this layer type: https://civitai.com/models/123930/gwendolyn-tennyson-ben-10-ia3
|
|
"""
|
|
|
|
def __init__(self, weight: torch.Tensor, on_input: torch.Tensor, bias: Optional[torch.Tensor]):
|
|
super().__init__(alpha=None, bias=bias)
|
|
self.weight = weight
|
|
self.on_input = on_input
|
|
|
|
def _rank(self) -> int | None:
|
|
return None
|
|
|
|
@classmethod
|
|
def from_state_dict_values(
|
|
cls,
|
|
values: Dict[str, torch.Tensor],
|
|
):
|
|
bias = cls._parse_bias(
|
|
values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
|
|
)
|
|
layer = cls(
|
|
weight=values["weight"],
|
|
on_input=values["on_input"],
|
|
bias=bias,
|
|
)
|
|
cls.warn_on_unhandled_keys(
|
|
values=values,
|
|
handled_keys={
|
|
# Default keys.
|
|
"bias_indices",
|
|
"bias_values",
|
|
"bias_size",
|
|
# Layer-specific keys.
|
|
"weight",
|
|
"on_input",
|
|
},
|
|
)
|
|
return layer
|
|
|
|
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
weight = self.weight
|
|
if not self.on_input:
|
|
weight = weight.reshape(-1, 1)
|
|
return orig_weight * weight
|
|
|
|
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
|
super().to(device, dtype)
|
|
self.weight = self.weight.to(device, dtype)
|
|
self.on_input = self.on_input.to(device, dtype)
|