Update interrogate

This commit is contained in:
lnyan 2022-10-23 19:27:45 +08:00
parent 2da1f2efb8
commit 90b8128aea

View File

@ -15,7 +15,6 @@ from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from blip_model.blip import blip_decoder
from transformers import CLIPTokenizer, CLIPModel
from transformers import CLIPProcessor, CLIPModel
@ -75,6 +74,7 @@ class Interrogator:
self.text_feature_lst = [torch.load(os.path.join(data_path, f"{i}.pth")) for i in range(5)]
def get_blip(self):
from blip_model.blip import blip_decoder
blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base')
blip_model.eval()
self.blip_model = blip_model