Add a unit test for a LoRA patch applied to a quantized linear layer with weights streamed from CPU to GPU.

Ryan Dick 2024-12-29 17:14:55 +00:00
parent a8bef59699
commit 52fc5a64d4


@@ -484,3 +484,47 @@ def test_quantized_linear_sidecar_patches(
    output_linear_patched = linear_layer_custom(input)
    output_quantized_patched = quantized_linear_layer_custom(input)
    assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2)

@parameterize_cuda_and_mps
def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
    device: str,
    quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
    patch_under_test: PatchUnderTest,
):
    """Test that a linear layer with sidecar patches produces the same output when it runs on the device as when
    it lives on the CPU and its weights and patches are autocast to the device.
    """
    patches, input = patch_under_test
    _, quantized_linear_layer = quantized_linear_layer_under_test

    # Move everything to the device.
    layer_to_device_via_state_dict(quantized_linear_layer, device)
    input = input.to(torch.device(device))

    # Wrap the quantized linear layer in a custom layer and add the patches to it.
    quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
    for patch, weight in patches:
        patch.to(torch.device(device))
        quantized_linear_layer_custom.add_patch(patch, weight)

    # Run inference with the custom layer on the device.
    expected_output = quantized_linear_layer_custom(input)

    # Move the custom layer to the CPU.
    layer_to_device_via_state_dict(quantized_linear_layer_custom, "cpu")

    # Move the patches to the CPU and re-apply them.
    quantized_linear_layer_custom.clear_patches()
    for patch, weight in patches:
        patch.to(torch.device("cpu"))
        quantized_linear_layer_custom.add_patch(patch, weight)

    # Run inference with the input on the device and all layer weights on the CPU. The weights should be
    # autocast to the device.
    autocast_output = quantized_linear_layer_custom(input)
    assert autocast_output.device.type == device

    # Assert that the outputs with and without autocasting are the same.
    assert torch.allclose(expected_output, autocast_output, atol=1e-6)
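For context, the test relies on the layer_to_device_via_state_dict helper imported from the test utilities. A plausible minimal implementation (an assumption for illustration, not necessarily the repo's actual code) round-trips the state dict, since calling .to() directly is unreliable for some quantized layer implementations:

import torch

def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str) -> None:
    # Hypothetical sketch of the helper used by the test above; the real
    # implementation lives in the repo's test utilities.
    # Quantized layers often misbehave under Module.to(), so move every tensor in
    # the state dict explicitly and re-assign it (assign=True needs PyTorch >= 2.1).
    state_dict = {k: v.to(torch.device(device)) for k, v in layer.state_dict().items()}
    layer.load_state_dict(state_dict, assign=True)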
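The behaviour under test, CPU-resident weights streamed to the input's device at forward time, can be sketched for a plain (non-quantized) linear layer as follows. StreamingLinearSketch is an illustrative stand-in under stated assumptions, not the repo's custom-layer implementation:

import torch

class StreamingLinearSketch(torch.nn.Module):
    # Illustrative stand-in (an assumption, not the repo's custom layer): the
    # wrapped layer's parameters stay on the CPU and are copied to the input's
    # device on every forward call.
    def __init__(self, wrapped: torch.nn.Linear):
        super().__init__()
        self._wrapped = wrapped

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Stream the weights to the input's device for this forward pass only;
        # the wrapped module's own tensors never move.
        weight = self._wrapped.weight.to(input.device)
        bias = self._wrapped.bias.to(input.device) if self._wrapped.bias is not None else None
        return torch.nn.functional.linear(input, weight, bias)

# Usage: the module stays on the CPU while inference runs on the GPU.
# streaming = StreamingLinearSketch(torch.nn.Linear(8, 8))
# output = streaming(torch.randn(1, 8, device="cuda"))
# assert output.device.type == "cuda"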