From 2144d21f80ec140f01dfe0e04f0540baa738a915 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 6 Dec 2024 21:49:24 +0000 Subject: [PATCH] Maintain a read-only CPU state dict copy in CachedModelWithPartialLoad. --- .../cached_model_with_partial_load.py | 70 ++++++++++++------- .../test_cached_model_with_partial_load.py | 48 +++++++++++-- 2 files changed, 89 insertions(+), 29 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 13cd442adb..21503db08a 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -6,6 +6,18 @@ from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_con from invokeai.backend.util.calc_tensor_size import calc_tensor_size +def set_nested_attr(obj: object, attr: str, value: object): + """A helper function that extends setattr() to support nested attributes. + + Example: + set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight) + """ + attrs = attr.split(".") + for attr in attrs[:-1]: + obj = getattr(obj, attr) + setattr(obj, attrs[-1], value) + + class CachedModelWithPartialLoad: """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device. @@ -17,6 +29,9 @@ class CachedModelWithPartialLoad: self._model = model self._compute_device = compute_device + # A CPU read-only copy of the model's state dict. + self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict() + # Monkey-patch the model to add autocasting to the model's forward method. add_autocast_to_module_forward(model, compute_device) @@ -32,8 +47,8 @@ class CachedModelWithPartialLoad: def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: """Get a read-only copy of the model's state dict in RAM.""" - # TODO(ryand): Document this better and implement it. - return None + # TODO(ryand): Document this better. + return self._cpu_state_dict def total_bytes(self) -> int: """Get the total size (in bytes) of all the weights in the model.""" @@ -55,6 +70,7 @@ class CachedModelWithPartialLoad: """Unload all weights from VRAM.""" return self.partial_unload_from_vram(self.total_bytes()) + @torch.no_grad() def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: """Load more weights into VRAM without exceeding vram_bytes_to_load. @@ -63,32 +79,39 @@ class CachedModelWithPartialLoad: """ vram_bytes_loaded = 0 - # TODO(ryand): Should we use self._model.apply(...) instead and move modules around instead of moving tensors? - # This way we don't have to use the private _apply() method. - def to_vram(t: torch.Tensor): - nonlocal vram_bytes_loaded - + # TODO(ryand): Iterate over buffers too? + for key, param in self._model.named_parameters(): # Skip parameters that are already on the compute device. - if t.device.type == self._compute_device.type: - return t + if param.device.type == self._compute_device.type: + continue # Check the size of the parameter. - param_size = calc_tensor_size(t) + param_size = calc_tensor_size(param) if vram_bytes_loaded + param_size > vram_bytes_to_load: # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really # worth continuing to search for a smaller parameter that would fit? 
- return t + continue + + # Copy the parameter to the compute device. + # We use the 'overwrite' strategy from torch.nn.Module._apply(). + # TODO(ryand): For some edge cases (e.g. quantized models?), we may need to support other strategies (e.g. + # swap). + assert isinstance(param, torch.nn.Parameter) + assert param.is_leaf + out_param = torch.nn.Parameter(param.to(self._compute_device, copy=True), requires_grad=param.requires_grad) + set_nested_attr(self._model, key, out_param) + # We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be + # handling gradients. We assert that this assumption is true. + assert param.grad is None vram_bytes_loaded += param_size - return t.to(self._compute_device) - - self._model._apply(to_vram) if self._cur_vram_bytes is not None: self._cur_vram_bytes += vram_bytes_loaded return vram_bytes_loaded + @torch.no_grad() def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int: """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded. @@ -97,19 +120,18 @@ class CachedModelWithPartialLoad: """ vram_bytes_freed = 0 - def from_vram(t: torch.Tensor): - nonlocal vram_bytes_freed - + # TODO(ryand): Iterate over buffers too? + for key, param in self._model.named_parameters(): if vram_bytes_freed >= vram_bytes_to_free: - return t + break - if t.device.type != self._compute_device.type: - return t + if param.device.type != self._compute_device.type: + continue - vram_bytes_freed += calc_tensor_size(t) - return t.to("cpu") - - self._model._apply(from_vram) + # Create a new parameter, but inject the existing CPU tensor into it. + out_param = torch.nn.Parameter(self._cpu_state_dict[key], requires_grad=param.requires_grad) + set_nested_attr(self._model, key, out_param) + vram_bytes_freed += calc_tensor_size(param) if self._cur_vram_bytes is not None: self._cur_vram_bytes -= vram_bytes_freed diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py index 637b2719ba..dd23b527d8 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py @@ -4,6 +4,7 @@ import torch from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) +from invokeai.backend.util.calc_tensor_size import calc_tensor_size from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule parameterize_mps_and_cuda = pytest.mark.parametrize( @@ -33,41 +34,53 @@ def test_cached_model_total_bytes(device: str): @parameterize_mps_and_cuda def test_cached_model_cur_vram_bytes(device: str): model = DummyModule() + # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) assert cached_model.cur_vram_bytes() == 0 + # Full load the model into VRAM. cached_model.full_load_to_vram() assert cached_model.cur_vram_bytes() > 0 assert cached_model.cur_vram_bytes() == cached_model.total_bytes() + assert all(p.device.type == device for p in model.parameters()) @parameterize_mps_and_cuda def test_cached_model_partial_load(device: str): model = DummyModule() + # Model starts in CPU memory. 
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() assert cached_model.cur_vram_bytes() == 0 + # Partially load the model into VRAM. target_vram_bytes = int(model_total_bytes * 0.6) loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes) assert loaded_bytes > 0 assert loaded_bytes < model_total_bytes assert loaded_bytes == cached_model.cur_vram_bytes() + assert loaded_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == device) @parameterize_mps_and_cuda def test_cached_model_partial_unload(device: str): model = DummyModule() - model.to(device=torch.device(device)) + # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Full load the model into VRAM. + cached_model.full_load_to_vram() assert cached_model.cur_vram_bytes() == model_total_bytes + # Partially unload the model from VRAM. bytes_to_free = int(model_total_bytes * 0.4) freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free) assert freed_bytes >= bytes_to_free assert freed_bytes < model_total_bytes assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes() + assert freed_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == "cpu") @parameterize_mps_and_cuda @@ -84,7 +97,7 @@ def test_cached_model_full_load(device: str): assert loaded_bytes > 0 assert loaded_bytes == model_total_bytes assert loaded_bytes == cached_model.cur_vram_bytes() - assert all(p.device.type == device for p in cached_model.model.parameters()) + assert all(p.device.type == device for p in model.parameters()) @parameterize_mps_and_cuda @@ -109,11 +122,11 @@ def test_cached_model_full_load_from_partial(device: str): assert loaded_bytes_2 < model_total_bytes assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes() assert loaded_bytes + loaded_bytes_2 == model_total_bytes - assert all(p.device.type == device for p in cached_model.model.parameters()) + assert all(p.device.type == device for p in model.parameters()) @parameterize_mps_and_cuda -def test_cached_model_full_unload(device: str): +def test_cached_model_full_unload_from_partial(device: str): model = DummyModule() cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) @@ -133,4 +146,29 @@ def test_cached_model_full_unload(device: str): assert unloaded_bytes > 0 assert unloaded_bytes == loaded_bytes assert cached_model.cur_vram_bytes() == 0 - assert all(p.device.type == "cpu" for p in cached_model.model.parameters()) + assert all(p.device.type == "cpu" for p in model.parameters()) + + +@parameterize_mps_and_cuda +def test_cached_model_get_cpu_state_dict(device: str): + model = DummyModule() + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + + # Model starts in CPU memory. + assert cached_model.cur_vram_bytes() == 0 + + # The CPU state dict can be accessed and has the expected properties. + cpu_state_dict = cached_model.get_cpu_state_dict() + assert cpu_state_dict is not None + assert len(cpu_state_dict) == len(model.state_dict()) + assert all(p.device.type == "cpu" for p in cpu_state_dict.values()) + + # Full load the model into VRAM. 
+ cached_model.full_load_to_vram() + assert cached_model.cur_vram_bytes() == cached_model.total_bytes() + + # The CPU state dict is still available, and still on the CPU. + cpu_state_dict = cached_model.get_cpu_state_dict() + assert cpu_state_dict is not None + assert len(cpu_state_dict) == len(model.state_dict()) + assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
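
For reference, the new set_nested_attr() helper exists to resolve the dotted keys produced by
named_parameters() so that partial_load_to_vram() / partial_unload_from_vram() can swap individual
parameters in place. A minimal, self-contained sketch of that mechanism follows. It is illustrative
only: the torch.nn.Sequential stand-in model is an assumption (the DummyModule test fixture is not
defined in this patch), and an equivalent of the helper is re-declared so the snippet runs on its own.

import torch


def set_nested_attr(obj: object, attr: str, value: object):
    """Equivalent of the helper added in this patch: setattr() extended to dotted attribute paths."""
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], value)


# Hypothetical stand-in model; any nn.Module with nested submodules behaves the same way.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))

# named_parameters() yields dotted keys like "0.weight", which is exactly the form that
# set_nested_attr() receives when the cache swaps a parameter for its on-device copy.
for key, param in list(model.named_parameters()):
    replacement = torch.nn.Parameter(param.detach().clone(), requires_grad=param.requires_grad)
    set_nested_attr(model, key, replacement)
    # get_parameter() resolves the same dotted path, so it should now return the new object.
    assert model.get_parameter(key) is replacement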
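
For context, a rough usage sketch of the behaviour exercised by the new tests: a partial load into
VRAM, the read-only CPU state dict staying on the CPU, and a partial unload that re-injects the
tensors held in that state dict. This is a sketch under assumptions, not part of the patch: it
presumes a CUDA device is available and uses a plain torch.nn.Sequential in place of the DummyModule
test fixture.

import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)

# Hypothetical stand-in for the DummyModule test fixture (its definition is not part of this patch).
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))

cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))

# Load roughly 60% of the weights into VRAM; the return value is the number of bytes actually moved.
loaded_bytes = cached_model.partial_load_to_vram(int(cached_model.total_bytes() * 0.6))
assert loaded_bytes == cached_model.cur_vram_bytes()

# The read-only CPU copy introduced by this patch stays on the CPU regardless of how much of the
# live model currently resides in VRAM.
cpu_state_dict = cached_model.get_cpu_state_dict()
assert cpu_state_dict is not None
assert all(t.device.type == "cpu" for t in cpu_state_dict.values())

# Unloading re-injects the tensors held in the CPU state dict, so the parameters return to the CPU
# without copying data back from VRAM.
freed_bytes = cached_model.partial_unload_from_vram(loaded_bytes)
assert freed_bytes == loaded_bytes
assert cached_model.cur_vram_bytes() == 0
assert all(p.device.type == "cpu" for p in model.parameters())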