# demo.py
import argparse
import json
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import gradio as gr
from PIL import Image

import utils
from emu_models.modeling_emu_clip import Emu_clip
from llava.processors.blip_processors import BlipCaptionProcessor, BlipImageTrainProcessor

def parse_args():
    parser = argparse.ArgumentParser(description="Image Retrieval Demo")
    parser.add_argument("--model-cfg", type=str, default='emu_models/Emu-8B_frozenvis_cliploss.json',
                        help="path to model configuration file")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="path to model checkpoint")
    parser.add_argument("--image-feat-path", type=str, required=True,
                        help="path to pre-computed image features")
    parser.add_argument("--annotation-path", type=str, required=True,
                        help="path to image annotation JSON file")
    parser.add_argument("--image-root", type=str, required=True,
                        help="root directory of images")
    parser.add_argument("--gpu-id", type=int, default=0,
                        help="index of the GPU on which to load the model")
    return parser.parse_args()
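
# Example invocation. All paths below are illustrative placeholders, not files
# shipped with this repo:
#   python demo.py \
#       --checkpoint /path/to/emu_checkpoint.pt \
#       --image-feat-path /path/to/precomputed_feats.pt \
#       --annotation-path /path/to/annotations.json \
#       --image-root /path/to/images \
#       --gpu-id 0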

class Search:
    def __init__(self, model, vis_processor, device, image_feat_path=None, annotation_path=None, image_root=None):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        if image_feat_path is None:
            raise ValueError("Please provide image_feat_path for pre-computed image features")
        if annotation_path is None:
            raise ValueError("Please provide annotation_path for image annotations")
        if image_root is None:
            raise ValueError("Please provide image_root for image directory")

        # Project the pre-computed visual features into the shared text-image
        # embedding space and L2-normalize them once, up front, so each query
        # reduces to a single matrix product. The features are moved onto the
        # model's device before the projection layers are applied.
        image_features = torch.load(image_feat_path, map_location='cpu').to(self.device)
        image_features = self.model.ln_visual(image_features)
        image_features = self.model.proj_head(image_features).float()
        self.image_embeds = F.normalize(self.model.decoder.lm.vision_head(image_features), dim=-1)

        with open(annotation_path) as f:
            self.urls = json.load(f)
        self.root = image_root

    def ret(self, text_input, image=None):
        """Return the path of the gallery image that best matches the query."""
        if image is None:
            image_features = None
            image_placeholder = ''
        else:
            if image.dim() == 3:  # add a batch dimension if it is missing
                image = image.unsqueeze(0)
            image_features = self.model.visual.forward_features(image)
            image_features = self.model.ln_visual(image_features)
            # Keep only the pooled ([CLS]) token feature for conditioning.
            image_features = self.model.proj_head(image_features[:, 0:1, :])
            image_placeholder = "[IMG]" + "<image>" + "[/IMG]"

        prompt = [image_placeholder + text_input]
        input_tokens = self.model.decoder.tokenizer(
            prompt, padding="longest", return_tensors="pt",
            add_special_tokens=True, truncation=True, max_length=70).to(self.device)

        # Embed the (optionally image-conditioned) query and score it against
        # every pre-computed image embedding; the best-scoring index selects
        # the returned image path.
        text_embed = self.model.decoder.gen(image_features,
                                            text_input=input_tokens.input_ids,
                                            text_mask=input_tokens.attention_mask)
        scores = text_embed @ self.image_embeds.t()
        idx = scores.argmax().item()
        return os.path.join(self.root, self.urls[idx]['image_id'])
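
# A minimal sketch of how the pre-computed feature file expected by
# --image-feat-path could be produced. This is an assumption inferred from how
# Search consumes the tensor (one pooled feature vector per image, which is
# what the 2-D scoring step above requires); the pipeline actually used to
# build the file is not part of this script, and build_image_feature_file is
# a hypothetical helper that is never called below.
def build_image_feature_file(model, vis_processor, annotations, image_root, out_path):
    feats = []
    with torch.no_grad():
        for ann in annotations:
            img = Image.open(os.path.join(image_root, ann['image_id'])).convert('RGB')
            img = vis_processor(img).unsqueeze(0).to(next(model.parameters()).device)
            # Keep the pooled ([CLS]) token feature, shape [1, D].
            feats.append(model.visual.forward_features(img)[:, 0].cpu())
    torch.save(torch.cat(feats, dim=0), out_path)  # shape [N, D]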

text_processor = BlipCaptionProcessor()
image_processor = BlipImageTrainProcessor(image_size=224)

print('Initializing Image Retrieval System')
args = parse_args()
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

print("Creating model", flush=True)
with open(args.model_cfg, "r", encoding="utf8") as f:
    model_cfg = json.load(f)
print(f"=====> model_cfg: {model_cfg}")

model_args = argparse.Namespace(loss='clip')
model = Emu_clip(**model_cfg, cast_dtype=torch.bfloat16, args=model_args)

ckpt = torch.load(args.checkpoint, map_location="cpu")
msg = model.load_state_dict(ckpt, strict=False)
print(msg)
model = model.to(device)
# Only parameters with requires_grad=True are counted here.
print("### Trainable Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

search = Search(model, image_processor, device=device,
                image_feat_path=args.image_feat_path,
                annotation_path=args.annotation_path,
                image_root=args.image_root)
print('Initialization Finished')
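
# The retrieval path can also be exercised without the Gradio UI, e.g. as a
# quick smoke test (hypothetical usage, not part of the original script):
#   name = search.ret("a dog playing in the snow")
#   Image.open(name).show()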

def gradio_reset(chat_state, img_list):
    # Clear conversational state (wired up in the Gradio layout) and
    # re-enable the input widgets.
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (gr.update(value=None, interactive=True),
            gr.update(placeholder='', interactive=True),
            gr.update(value=None, interactive=False),
            gr.update(value="Upload & Start Chat", interactive=True))

def upload_img(gr_img):
    # Normalize whatever Gradio hands us (path, PIL image, or tensor) into a
    # model-ready batched tensor, and return it alongside the component
    # update so the Gradio wiring can store it, e.g. in a gr.State.
    if isinstance(gr_img, str):
        raw_image = Image.open(gr_img).convert('RGB')
        gr_img = image_processor(raw_image).unsqueeze(0).to(device)
    elif isinstance(gr_img, Image.Image):
        gr_img = image_processor(gr_img).unsqueeze(0).to(device)
    elif isinstance(gr_img, torch.Tensor):
        if gr_img.dim() == 3:
            gr_img = gr_img.unsqueeze(0)
    else:
        gr_img = None
    return gr_img, gr.update(interactive=False)

def gradio_ask(image, user_message, out_image):
    if len(user_message) == 0:
        return image, gr.update(interactive=True, placeholder='Input should not be empty!'), out_image
    # search.ret handles both the text-only and the image-conditioned case,
    # so a single call covers both.
    name = search.ret(user_message, image)
    out_image = Image.open(name)
    return image, gr.update(interactive=False), out_image

title = """<h1 align="center">Image Retrieval Demo</h1>"""
description = """<h3>This is a text-to-image retrieval demo. Enter a text query to search for relevant images!</h3>"""
article = """<p>This demo allows you to search for images using natural language queries.</p>"""
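
# The Gradio layout itself is not included in this portion of the file. A
# minimal wiring consistent with the callbacks above might look like the
# commented sketch below; the component names and event hookup are
# assumptions, not the original code.
# with gr.Blocks() as demo:
#     gr.Markdown(title)
#     gr.Markdown(description)
#     with gr.Row():
#         image = gr.Image(type="pil")
#         out_image = gr.Image(interactive=False)
#     img_state = gr.State()
#     upload = gr.Button("Upload & Start Chat")
#     upload.click(upload_img, [image], [img_state, image])
#     text_input = gr.Textbox(label="Query")
#     text_input.submit(gradio_ask, [img_state, text_input, out_image],
#                       [img_state, text_input, out_image])
#     gr.Markdown(article)
# demo.launch()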