mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-07 03:17:05 +08:00
Strategize slicing based on free [V]RAM (#2572)
Strategize slicing based on free [V]RAM when not using xformers. Free [V]RAM is evaluated at every generation. When there's enough memory, the entire generation occurs without slicing. If there is not enough free memory, we use diffusers' sliced attention.
This commit is contained in:
parent
7c86130a3d
commit
9eed1919c2
@ -223,7 +223,7 @@ class Generate:
|
||||
self.model_name = model or fallback
|
||||
|
||||
# for VRAM usage statistics
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# gets rid of annoying messages about random seed
|
||||
@ -592,20 +592,24 @@ class Generate:
|
||||
self.print_cuda_stats()
|
||||
return results
|
||||
|
||||
def clear_cuda_cache(self):
|
||||
def gather_cuda_stats(self):
|
||||
if self._has_cuda():
|
||||
self.max_memory_allocated = max(
|
||||
self.max_memory_allocated,
|
||||
torch.cuda.max_memory_allocated()
|
||||
torch.cuda.max_memory_allocated(self.device)
|
||||
)
|
||||
self.memory_allocated = max(
|
||||
self.memory_allocated,
|
||||
torch.cuda.memory_allocated()
|
||||
torch.cuda.memory_allocated(self.device)
|
||||
)
|
||||
self.session_peakmem = max(
|
||||
self.session_peakmem,
|
||||
torch.cuda.max_memory_allocated()
|
||||
torch.cuda.max_memory_allocated(self.device)
|
||||
)
|
||||
|
||||
def clear_cuda_cache(self):
|
||||
if self._has_cuda():
|
||||
self.gather_cuda_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def clear_cuda_stats(self):
|
||||
@ -614,6 +618,7 @@ class Generate:
|
||||
|
||||
def print_cuda_stats(self):
|
||||
if self._has_cuda():
|
||||
self.gather_cuda_stats()
|
||||
print(
|
||||
'>> Max VRAM used for this generation:',
|
||||
'%4.2fG.' % (self.max_memory_allocated / 1e9),
|
||||
|
@ -301,10 +301,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
textual_inversion_manager=self.textual_inversion_manager
|
||||
)
|
||||
|
||||
self._enable_memory_efficient_attention()
|
||||
|
||||
|
||||
def _enable_memory_efficient_attention(self):
|
||||
def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
|
||||
"""
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
@ -317,7 +315,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
|
||||
pass
|
||||
else:
|
||||
self.enable_attention_slicing(slice_size='max')
|
||||
if self.device.type == 'cpu' or self.device.type == 'mps':
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.device.type == 'cuda':
|
||||
mem_free, _ = torch.cuda.mem_get_info(self.device)
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
||||
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
||||
max_size_required_for_baddbmm = \
|
||||
16 * \
|
||||
latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
|
||||
bytes_per_element_needed_for_baddbmm_duplication
|
||||
if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0): # 3.3 / 4.0 is from old Invoke code
|
||||
self.enable_attention_slicing(slice_size='max')
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
|
||||
|
||||
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
@ -377,6 +392,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise: torch.Tensor,
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if run_id is None:
|
||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||
if additional_guidance is None:
|
||||
|
Loading…
Reference in New Issue
Block a user