Flux VAE is broken for float16; force bfloat16 or float32, whichever is compatible

This commit is contained in:
David Burnett 2024-10-27 14:18:12 +00:00 committed by Kent Keirsey
parent a01d44f813
commit 7b5efc2203
3 changed files with 17 additions and 1 deletion

View File

@ -318,6 +318,13 @@ class AutoEncoder(nn.Module):
def decode(self, z: Tensor) -> Tensor:
    """Decode latent tensor *z* back through the VAE decoder.

    Undoes the scale/shift normalization first. Because this VAE is
    known to be broken in float16, a float16 input is promoted to
    bfloat16 when the backend supports it, falling back to float32
    otherwise (mirroring the dtype choice made at model-load time).
    """
    latents = z / self.scale_factor + self.shift_factor
    if latents.dtype == torch.float16:
        # Prefer bfloat16; some backends raise TypeError for it,
        # in which case float32 is the safe fallback.
        try:
            latents = latents.to(torch.bfloat16)
        except TypeError:
            latents = latents.to(torch.float32)
    return self.decoder(latents)
def forward(self, x: Tensor) -> Tensor:

View File

@ -35,6 +35,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
self._torch_device = TorchDevice.choose_torch_device()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@ -84,7 +84,15 @@ class FluxVAELoader(ModelLoader):
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
# VAE is broken in float16, which MPS defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.tensor([1.0], dtype=torch.float32, device=self._torch_device).dtype
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
return model