Merge branch 'main' into pin-timm-for-llava

This commit is contained in:
Eugene Brodsky 2025-03-28 09:20:30 -04:00 committed by GitHub
commit c9992914d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 80 additions and 67 deletions

View File

@ -1,7 +1,6 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@ -25,7 +24,6 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
tags=["lora", "model", "flux"],
category="model",
version="1.1.1",
classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):
"""LoRA model and Image to use with FLUX transformer generation."""

View File

@ -3,7 +3,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@ -52,7 +51,6 @@ class FluxControlNetOutput(BaseInvocationOutput):
tags=["controlnet", "flux"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
"""Collect FLUX ControlNet info to pass to other nodes."""

View File

@ -10,7 +10,7 @@ from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@ -64,7 +64,6 @@ from invokeai.backend.util.devices import TorchDevice
tags=["image", "flux"],
category="image",
version="3.3.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""

View File

@ -31,7 +31,7 @@ class FluxFillOutput(BaseInvocationOutput):
tags=["inpaint"],
category="inpaint",
version="1.0.0",
classification=Classification.Prototype,
classification=Classification.Beta,
)
class FluxFillInvocation(BaseInvocation):
"""Prepare the FLUX Fill conditioning data."""

View File

@ -4,7 +4,7 @@ from typing import List, Literal, Union
from pydantic import field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, UIType
from invokeai.app.invocations.ip_adapter import (
CLIP_VISION_MODEL_MAP,
@ -28,7 +28,6 @@ from invokeai.backend.model_manager.config import (
tags=["ip_adapter", "control"],
category="ip_adapter",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxIPAdapterInvocation(BaseInvocation):
"""Collects FLUX IP-Adapter info to pass to other nodes."""

View File

@ -3,7 +3,6 @@ from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@ -32,7 +31,6 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
tags=["lora", "model", "flux"],
category="model",
version="1.2.1",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
@ -111,7 +109,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
tags=["lora", "model", "flux"],
category="model",
version="1.3.1",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a FLUX transformer."""

View File

@ -3,7 +3,6 @@ from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@ -41,7 +40,6 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
tags=["model", "flux"],
category="model",
version="1.0.6",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""

View File

@ -45,7 +45,7 @@ class FluxReduxOutput(BaseInvocationOutput):
tags=["ip_adapter", "control"],
category="ip_adapter",
version="2.0.0",
classification=Classification.Prototype,
classification=Classification.Beta,
)
class FluxReduxInvocation(BaseInvocation):
"""Runs a FLUX Redux model to generate a conditioning tensor."""

View File

@ -4,7 +4,7 @@ from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
@ -30,7 +30,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.1.2",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""

View File

@ -355,7 +355,6 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
tags=["image", "unsharp_mask"],
category="image",
version="1.2.2",
classification=Classification.Beta,
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies an unsharp mask filter to an image"""
@ -1096,6 +1095,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
If the fade size is 0, the mask is returned as-is.
"""
mask: ImageField = InputField(description="The mask to expand")
@ -1105,6 +1105,11 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
if self.fade_size_px == 0:
# If the fade size is 0, just return the mask as-is.
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
np_mask = numpy.array(pil_mask)
# Threshold the mask to create a binary mask - 0 for black, 255 for white
@ -1265,7 +1270,6 @@ class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""
@ -1292,7 +1296,6 @@ class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Paste the source image into the target image at the given bounding box.

View File

@ -4,7 +4,7 @@ import torch
from PIL.Image import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
@ -13,7 +13,14 @@ from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.util.devices import TorchDevice
@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
@invocation(
"llava_onevision_vllm",
title="LLaVA OneVision VLLM",
tags=["vllm"],
category="vllm",
version="1.0.0",
classification=Classification.Beta,
)
class LlavaOnevisionVllmInvocation(BaseInvocation):
"""Run a LLaVA OneVision VLLM model."""

View File

@ -4,7 +4,6 @@ from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
)
@ -58,7 +57,6 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
tags=["conditioning"],
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)
class AlphaMaskToTensorInvocation(BaseInvocation):
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
@ -87,7 +85,6 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
tags=["conditioning"],
category="conditioning",
version="1.1.0",
classification=Classification.Beta,
)
class InvertTensorMaskInvocation(BaseInvocation):
"""Inverts a tensor mask."""
@ -234,7 +231,6 @@ WHITE = ColorField(r=255, g=255, b=255, a=255)
tags=["mask"],
category="mask",
version="1.0.0",
classification=Classification.Beta,
)
class GetMaskBoundingBoxInvocation(BaseInvocation):
"""Gets the bounding box of the given mask image."""

View File

@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@ -124,7 +123,6 @@ class ModelIdentifierOutput(BaseInvocationOutput):
tags=["model"],
category="model",
version="1.0.1",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as

View File

@ -6,7 +6,7 @@ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@ -36,7 +36,6 @@ from invokeai.backend.util.devices import TorchDevice
tags=["image", "sd3"],
category="image",
version="1.1.1",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""

View File

@ -2,7 +2,7 @@ import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@ -25,7 +25,6 @@ from invokeai.backend.util.devices import TorchDevice
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
version="1.0.1",
classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image."""

View File

@ -3,7 +3,6 @@ from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@ -34,7 +33,6 @@ class Sd3ModelLoaderOutput(BaseInvocationOutput):
tags=["model", "sd3"],
category="model",
version="1.0.1",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""

View File

@ -11,7 +11,7 @@ from transformers import (
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
@ -33,7 +33,6 @@ SD3_T5_MAX_SEQ_LEN = 256
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.1",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""

View File

@ -7,7 +7,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
@ -56,7 +56,6 @@ def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> C
title="Tiled Multi-Diffusion Denoise - SD1.5, SDXL",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
version="1.0.1",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):

View File

@ -7,7 +7,6 @@ from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@ -40,7 +39,6 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class CalculateImageTilesInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@ -74,7 +72,6 @@ class CalculateImageTilesInvocation(BaseInvocation):
tags=["tiles"],
category="tiles",
version="1.1.1",
classification=Classification.Beta,
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@ -117,7 +114,6 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@ -168,7 +164,6 @@ class TileToPropertiesOutput(BaseInvocationOutput):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class TileToPropertiesInvocation(BaseInvocation):
"""Split a Tile into its individual properties."""
@ -201,7 +196,6 @@ class PairTileImageOutput(BaseInvocationOutput):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class PairTileImageInvocation(BaseInvocation):
"""Pair an image with its tile properties."""
@ -230,7 +224,6 @@ BLEND_MODES = Literal["Linear", "Seam"]
tags=["tiles"],
category="tiles",
version="1.1.1",
classification=Classification.Beta,
)
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Merge multiple tile images into a single image."""

View File

@ -47,3 +47,10 @@ class LlavaOnevisionModel(RawModel):
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._vllm_model.to(device=device, dtype=dtype)
def calc_size(self) -> int:
"""Get size of the model in memory in bytes."""
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._vllm_model)

View File

@ -67,6 +67,11 @@ class InvalidModelConfigException(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class FSLayout(Enum):
FILE = "file"
DIRECTORY = "directory"
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@ -102,29 +107,31 @@ class ModelOnDisk:
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
# TODO: Revisit checkpoint vs diffusers terminology
self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
self._state_dict_cache = {}
def hash(self):
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self):
if self.format_type == ModelFormat.Checkpoint:
def size(self) -> int:
if self.layout == FSLayout.FILE:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
def component_paths(self) -> set[Path]:
if self.layout == FSLayout.FILE:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.layout == FSLayout.FILE:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
@ -140,14 +147,30 @@ class ModelOnDisk:
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
@staticmethod
def load_state_dict(path: Path):
def load_state_dict(self, path: Optional[Path] = None) -> Dict[str | int, Any]:
if path in self._state_dict_cache:
return self._state_dict_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
@ -156,6 +179,7 @@ class ModelOnDisk:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
self._state_dict_cache[path] = state_dict
return state_dict
@ -238,11 +262,13 @@ class ModelConfigBase(ABC, BaseModel):
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
if not config_cls.matches(mod):
continue
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
@ -285,9 +311,6 @@ class ModelConfigBase(ABC, BaseModel):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
@ -563,7 +586,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
if mod.layout == FSLayout.FILE:
return False
config_path = mod.path / "config.json"

View File

@ -15,6 +15,7 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.taxonomy import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@ -50,6 +51,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
SegmentAnythingPipeline,
DepthAnythingPipeline,
SigLipPipeline,
LlavaOnevisionModel,
),
):
return model.calc_size()

View File

@ -2344,8 +2344,9 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Workflows: New and improved Workflow Library.",
"FLUX: Support for FLUX Redux & FLUX Fill in Workflows and Canvas."
"Workflows: Support for custom string drop-downs in Workflow Builder.",
"FLUX: Support for FLUX Fill in Workflows and Canvas.",
"LLaVA OneVision VLLM: Beta support in Workflows."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@ -6451,6 +6451,7 @@ export type components = {
* @description Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
* The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
* The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
* If the fade size is 0, the mask is returned as-is.
*/
ExpandMaskWithFadeInvocation: {
/**

View File

@ -1 +1 @@
__version__ = "5.9.0rc2"
__version__ = "5.9.0"

View File

@ -71,7 +71,7 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
print(f"Created clone of {original.name} at {stripped.path}")
for component_path in stripped.component_paths():
original_state_dict = ModelOnDisk.load_state_dict(component_path)
original_state_dict = stripped.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
with open(component_path, "w") as f:
json.dump(stripped_state_dict, f, indent=4)