Reduce VRAM usage by half during model loading

* This moves the call to half() before model.to(device) to avoid copying the
full-precision model to the GPU. This improves loading speed and dramatically reduces memory usage

* This fix contributed by @mh-dm (Mihai)
This commit is contained in:
Lincoln Stein 2022-09-10 10:02:43 -04:00
parent 99122708ca
commit 5c43988862

View File

@ -536,9 +536,6 @@ class Generate:
sd = pl_sd['state_dict']
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.to(self.device)
model.eval()
if self.full_precision:
print(
@ -549,6 +546,8 @@ class Generate:
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
)
model.half()
model.to(self.device)
model.eval()
# usage statistics
toc = time.time()