mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2025-01-08 11:57:36 +08:00
prevent crash when switching to an invalid model
This commit is contained in:
parent
b17ca0a5e7
commit
71ee44a827
@ -802,6 +802,10 @@ class Generate:
|
||||
|
||||
# the model cache does the loading and offloading
|
||||
cache = self.model_cache
|
||||
if not cache.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return self.model
|
||||
|
||||
cache.print_vram_usage()
|
||||
|
||||
# have to get rid of all references to model in order
|
||||
|
@ -41,15 +41,22 @@ class ModelCache(object):
|
||||
self.stack = [] # this is an LRU FIFO
|
||||
self.current_model = None
|
||||
|
||||
def valid_model(self, model_name:str)->bool:
|
||||
'''
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
'''
|
||||
return model_name in self.config
|
||||
|
||||
def get_model(self, model_name:str):
|
||||
'''
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
'''
|
||||
if model_name not in self.config:
|
||||
if not self.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
return self.current_model
|
||||
|
||||
if self.current_model != model_name:
|
||||
if model_name not in self.models: # make room for a new one
|
||||
|
Loading…
Reference in New Issue
Block a user