proposed fix to work on mps systems

This commit is contained in:
Lincoln Stein 2022-10-12 11:08:27 -04:00
parent b537e92789
commit aa6aa68753
2 changed files with 6 additions and 3 deletions

View File

@ -195,15 +195,17 @@ class ModelCache(object):
torch.cuda.empty_cache()
def _model_to_cpu(self,model):
if self._has_cuda():
if self.device != 'cpu':
model.cond_stage_model.device = 'cpu'
model.first_stage_model.to('cpu')
model.cond_stage_model.to('cpu')
model.model.to('cpu')
return model.to('cpu')
return model.to('cpu')
else:
return model
def _model_from_cpu(self,model):
if self._has_cuda():
if self.device != 'cpu':
model.to(self.device)
model.first_stage_model.to(self.device)
model.cond_stage_model.to(self.device)

View File

@ -154,6 +154,7 @@ def main_loop(gen, opt, infile):
elif subcommand.startswith('switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
completer.add_history(command)
continue
elif subcommand.startswith('models'):