Add unit test to test that isinstance(...) behaves as expected with custom module types.

This commit is contained in:
Ryan Dick 2024-12-26 18:45:56 +00:00
parent a8b2c4c3d2
commit b0b699a01f

View File

@ -74,6 +74,17 @@ def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str):
layer.load_state_dict(state_dict, assign=True)
@parameterize_all_layer_types
def test_isinstance(orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool):
"""Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
orig_type = type(orig_layer)
apply_custom_layers_to_model(orig_layer)
assert isinstance(orig_layer, orig_type)
assert type(orig_layer) is not orig_type
@parameterize_all_devices
@parameterize_all_layer_types
def test_state_dict(device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool):