mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-04-04 22:43:40 +08:00
Flux Vae broke for float16, force bfloat16 or float32 were compatible
This commit is contained in:
parent
a01d44f813
commit
7b5efc2203
@ -318,6 +318,13 @@ class AutoEncoder(nn.Module):
|
||||
|
||||
def decode(self, z: Tensor) -> Tensor:
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
|
||||
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
|
||||
if z.dtype == torch.float16:
|
||||
try:
|
||||
z = z.to(torch.bfloat16)
|
||||
except TypeError:
|
||||
z = z.to(torch.float32)
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
@ -35,6 +35,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._torch_dtype = TorchDevice.choose_torch_dtype()
|
||||
self._torch_device = TorchDevice.choose_torch_device()
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
|
@ -84,7 +84,15 @@ class FluxVAELoader(ModelLoader):
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
# VAE is broken in float16, which mps defaults too
|
||||
if self._torch_dtype == torch.float16:
|
||||
try:
|
||||
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
|
||||
except TypeError:
|
||||
vae_dtype = torch.tensor([1.0], dtype=torch.float32, device=self._torch_device).dtype
|
||||
else:
|
||||
vae_dtype = self._torch_dtype
|
||||
model.to(vae_dtype)
|
||||
|
||||
return model
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user