2023-06-21 02:12:21 +03:00
|
|
|
from enum import Enum
|
2023-08-17 18:45:25 -04:00
|
|
|
from typing import Literal
|
2023-06-21 02:12:21 +03:00
|
|
|
from .base import (
|
|
|
|
ModelConfigBase,
|
|
|
|
BaseModelType,
|
|
|
|
ModelType,
|
|
|
|
ModelVariantType,
|
|
|
|
DiffusersModel,
|
|
|
|
SchedulerPredictionType,
|
|
|
|
classproperty,
|
|
|
|
OnnxRuntimeModel,
|
|
|
|
IAIOnnxRuntimeModel,
|
|
|
|
)
|
|
|
|
|
2023-07-27 09:37:37 -04:00
|
|
|
|
|
|
|
class StableDiffusionOnnxModelFormat(str, Enum):
    """Format identifiers for ONNX Stable Diffusion models.

    Because the enum mixes in ``str``, each member compares equal to its raw
    string value, e.g. ``StableDiffusionOnnxModelFormat.Onnx == "onnx"``.
    """

    # Member order is significant: it is the enum's iteration order.
    Olive = "olive"
    Onnx = "onnx"
|
|
|
|
|
2023-06-21 02:12:21 +03:00
|
|
|
|
2023-07-28 09:46:44 -04:00
|
|
|
class ONNXStableDiffusion1Model(DiffusersModel):
    """Stable Diffusion 1.x model stored in ONNX format.

    Extends ``DiffusersModel``. After the base initialiser runs, every entry
    of ``self.child_types`` whose value is ``OnnxRuntimeModel`` is replaced
    with ``IAIOnnxRuntimeModel``.
    """

    class Config(ModelConfigBase):
        # Only the "onnx" format is accepted; detect_format() never reports "olive" yet.
        model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Initialise the model and patch its ONNX child types.

        Args:
            model_path: Path to the model.
            base_model: Must be ``BaseModelType.StableDiffusion1``.
            model_type: Must be ``ModelType.ONNX``.

        Raises:
            AssertionError: if *base_model* or *model_type* do not match
                (only when assertions are enabled, i.e. not under ``python -O``).
        """
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        )

        # Replace the stock OnnxRuntimeModel entries with the project's wrapper.
        # Rebinding values of existing keys while iterating is safe because the
        # dict's size does not change.
        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel

        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Build a ``Config`` for the ONNX SD 1.x model at *path*.

        Args:
            path: Path to the model.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The result of ``cls.create_config`` for the detected format and variant.

        Raises:
            ValueError: if the UNet input channel count maps to no known variant.
        """
        model_format = cls.detect_format(path)

        # TODO: read the real UNet input channel count from the ONNX model.
        # Until then every model is treated as 4-channel, so the inpaint and
        # error branches below are currently unreachable.
        in_channels = 4

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            # ValueError subclasses Exception, so existing `except Exception` handlers still apply.
            raise ValueError(
                f"Unknown stable diffusion 1.* model format (unet in_channels={in_channels})"
            )

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        """Always ``True`` for this model class."""
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """Return the storage format of the model at *model_path*.

        Currently always ``StableDiffusionOnnxModelFormat.Onnx``; *model_path* is not inspected.
        """
        # TODO: Detect onnx vs olive
        return StableDiffusionOnnxModelFormat.Onnx

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Return *model_path* unchanged; no conversion is performed for this model type.

        *output_path*, *config* and *base_model* are accepted for interface
        compatibility and are ignored.
        """
        return model_path
|
|
|
|
|
|
|
|
|
2023-07-28 09:46:44 -04:00
|
|
|
class ONNXStableDiffusion2Model(DiffusersModel):
    """Stable Diffusion 2.x model stored in ONNX format.

    Extends ``DiffusersModel``. After the base initialiser runs, every entry
    of ``self.child_types`` whose value is ``OnnxRuntimeModel`` is replaced
    with ``IAIOnnxRuntimeModel``.
    """

    # TODO: check that configs are overwritten properly
    class Config(ModelConfigBase):
        # Only the "onnx" format is accepted; detect_format() never reports "olive" yet.
        model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Initialise the model and patch its ONNX child types.

        Args:
            model_path: Path to the model.
            base_model: Must be ``BaseModelType.StableDiffusion2``.
            model_type: Must be ``ModelType.ONNX``.

        Raises:
            AssertionError: if *base_model* or *model_type* do not match
                (only when assertions are enabled, i.e. not under ``python -O``).
        """
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.ONNX,
        )

        # Replace the stock OnnxRuntimeModel entries with the project's wrapper.
        # Rebinding values of existing keys while iterating is safe because the
        # dict's size does not change.
        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel

        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Build a ``Config`` for the ONNX SD 2.x model at *path*.

        The variant determines the scheduler settings: the normal variant gets
        ``VPrediction`` with ``upcast_attention=True``; inpaint and depth
        variants get ``Epsilon`` with ``upcast_attention=False``.

        Args:
            path: Path to the model.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The result of ``cls.create_config`` for the detected format, variant,
            prediction type and upcast flag.

        Raises:
            ValueError: if the UNet input channel count maps to no known variant.
        """
        model_format = cls.detect_format(path)

        # TODO: read the real UNet input channel count from the ONNX model.
        # Until then every model is treated as 4-channel, so the inpaint, depth
        # and error branches below are currently unreachable.
        in_channels = 4

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            # ValueError subclasses Exception, so existing `except Exception` handlers still apply.
            raise ValueError(
                f"Unknown stable diffusion 2.* model format (unet in_channels={in_channels})"
            )

        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True
        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        """Always ``True`` for this model class."""
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """Return the storage format of the model at *model_path*.

        Currently always ``StableDiffusionOnnxModelFormat.Onnx``; *model_path* is not inspected.
        """
        # TODO: Detect onnx vs olive
        return StableDiffusionOnnxModelFormat.Onnx

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Return *model_path* unchanged; no conversion is performed for this model type.

        *output_path*, *config* and *base_model* are accepted for interface
        compatibility and are ignored.
        """
        return model_path
|