Use a fixture to parameterize tests in test_all_custom_modules.py so that a fresh instance of the layer under test is initialized for each test.

Ryan Dick 2024-12-26 19:41:25 +00:00
parent b0b699a01f
commit 9692a36dd6

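The problem this solves, as a minimal sketch (the `test_shared_instance` / `fresh_layer` names below are illustrative, not from this commit): layers listed directly in pytest.mark.parametrize are constructed once, at collection time, and the same objects are reused by every test that takes the mark, so a test that mutates its layer (moving it to a device, casting its dtype) leaks that state into later tests. A fixture defers construction until each test actually runs:

import pytest
import torch

# Fragile: this Linear is built once, at collection time, and the same
# object is handed to every test run that uses this mark.
@pytest.mark.parametrize("layer", [torch.nn.Linear(8, 16)])
def test_shared_instance(layer: torch.nn.Module):
    layer.half()  # mutates the shared instance; later tests would see float16 weights

# Robust: the fixture body runs once per test (per param), so every test
# receives its own freshly constructed layer.
@pytest.fixture(params=["linear", "conv1d"])
def fresh_layer(request: pytest.FixtureRequest) -> torch.nn.Module:
    if request.param == "linear":
        return torch.nn.Linear(8, 16)
    return torch.nn.Conv1d(8, 16, 3)

def test_fresh_instance(fresh_layer: torch.nn.Module):
    fresh_layer.half()  # safe: this instance is private to this test run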

@@ -46,19 +46,43 @@ parameterize_cuda_and_mps = pytest.mark.parametrize(
     ],
 )
-parameterize_all_layer_types = pytest.mark.parametrize(
-    ("orig_layer", "layer_input", "supports_cpu_inference"),
-    [
-        (torch.nn.Linear(8, 16), torch.randn(1, 8), True),
-        (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True),
-        (torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True),
-        (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True),
-        (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True),
-        (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True),
-        (build_linear_8bit_lt_layer(), torch.randn(1, 32), False),
-        (build_linear_nf4_layer(), torch.randn(1, 64), False),
-    ],
+LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool]
+
+
+@pytest.fixture(
+    params=[
+        "linear",
+        "conv1d",
+        "conv2d",
+        "group_norm",
+        "embedding",
+        "linear_with_ggml_quantized_tensor",
+        "invoke_linear_8_bit_lt",
+        "invoke_linear_nf4",
+    ]
 )
+def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest:
+    """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test."""
+    layer_type = request.param
+    if layer_type == "linear":
+        return (torch.nn.Linear(8, 16), torch.randn(1, 8), True)
+    elif layer_type == "conv1d":
+        return (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True)
+    elif layer_type == "conv2d":
+        return (torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True)
+    elif layer_type == "group_norm":
+        return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True)
+    elif layer_type == "embedding":
+        return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True)
+    elif layer_type == "linear_with_ggml_quantized_tensor":
+        return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True)
+    elif layer_type == "invoke_linear_8_bit_lt":
+        return (build_linear_8bit_lt_layer(), torch.randn(1, 32), False)
+    elif layer_type == "invoke_linear_nf4":
+        return (build_linear_nf4_layer(), torch.randn(1, 64), False)
+    else:
+        raise ValueError(f"Unsupported layer_type: {layer_type}")
 
 
 def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str):
@@ -74,9 +98,9 @@ def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str):
     layer.load_state_dict(state_dict, assign=True)
 
 
-@parameterize_all_layer_types
-def test_isinstance(orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool):
+def test_isinstance(layer_under_test: LayerUnderTest):
     """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
+    orig_layer, _, _ = layer_under_test
     orig_type = type(orig_layer)
     apply_custom_layers_to_model(orig_layer)
@@ -86,9 +110,10 @@ def test_isinstance(orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool):
 @parameterize_all_devices
-@parameterize_all_layer_types
-def test_state_dict(device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool):
+def test_state_dict(device: str, layer_under_test: LayerUnderTest):
     """Test that .state_dict() behaves the same on the original layer and the wrapped layer."""
+    orig_layer, _, _ = layer_under_test
     # Get the original layer on the test device.
     orig_layer.to(device)
     orig_state_dict = orig_layer.state_dict()
@@ -108,11 +133,10 @@ def test_state_dict(device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool):
 @parameterize_all_devices
-@parameterize_all_layer_types
-def test_load_state_dict(
-    device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool
-):
+def test_load_state_dict(device: str, layer_under_test: LayerUnderTest):
     """Test that .load_state_dict() behaves the same on the original layer and the wrapped layer."""
+    orig_layer, _, _ = layer_under_test
     orig_layer.to(device)
     custom_layer = copy.deepcopy(orig_layer)
@@ -138,13 +162,12 @@ def test_load_state_dict(
 @parameterize_all_devices
-@parameterize_all_layer_types
-def test_inference_on_device(
-    device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool
-):
+def test_inference_on_device(device: str, layer_under_test: LayerUnderTest):
     """Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the
     device.
     """
+    orig_layer, layer_input, supports_cpu_inference = layer_under_test
     if device == "cpu" and not supports_cpu_inference:
         pytest.skip("Layer does not support CPU inference.")
@@ -164,13 +187,15 @@ def test_inference_on_device(
 @parameterize_cuda_and_mps
-@parameterize_all_layer_types
-def test_inference_autocast_from_cpu_to_device(
-    device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool
-):
+def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: LayerUnderTest):
"""Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the
device.
"""
+    orig_layer, layer_input, supports_cpu_inference = layer_under_test
     if device == "cpu" and not supports_cpu_inference:
         pytest.skip("Layer does not support CPU inference.")
 
     # Make sure the original layer is on the device.
     layer_to_device_via_state_dict(orig_layer, device)
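A side note on why the switch is drop-in for the decorated tests (an observation about pytest behavior, not something stated in the commit): a parametrized fixture combines with stacked pytest.mark.parametrize decorators exactly as the old mark did, because pytest collects the cross-product of fixture params and mark params, so test_state_dict and friends still run once per (device, layer) pair. A minimal sketch with illustrative names:

import pytest

@pytest.fixture(params=["linear", "conv1d"])
def layer_name(request: pytest.FixtureRequest) -> str:
    return request.param

@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_cross_product(device: str, layer_name: str):
    # Collected as four tests: (cpu, linear), (cpu, conv1d), (cuda, linear), (cuda, conv1d).
    assert device in ("cpu", "cuda")
    assert layer_name in ("linear", "conv1d")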