-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: test_vlm.py
More file actions
150 lines (113 loc) · 4.56 KB
/
test_vlm.py
File metadata and controls
150 lines (113 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
import numpy as np
from PIL import Image
import json
import copy
import os
import torch.nn.functional as F
import random
import re
CIFAR_BATCH_SIZE = 128
LM_BATCH_SIZE = 32
VL_BATCH_SIZE = 16
MAX_LENGTH = 128
HIDDEN_SIZE = 768
NUM_EPOCHS = 1
IMG_PATCH = '<img>'
NUM_IMG_TOKEN = 32
VLM_MAX_LENGTH = 32
# Function to set random seed
def set_seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def transform_fn(is_train):
if is_train:
return transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
else:
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Do not change
class LLaVADataset(Dataset):
def __init__(self, json_file, img_path, tokenizer, is_train):
super().__init__()
self.transform = transform_fn(is_train)
self.json_file = json_file
self.tokenizer = tokenizer
self.img_path = img_path
self.ignore_idx = -100
self.begin_signal = tokenizer.bos_token
self.end_signal = tokenizer.eos_token
with open(self.json_file) as json_file:
data = json.load(json_file)
if is_train:
data = data[:1000]
else:
data = data[1000:]
self.data = data
def preprocess(self, conversation):
question = self.begin_signal + "human: " + conversation[0]['value'] + self.end_signal
answer = self.begin_signal + "assistant: " + conversation[1]['value'] + self.end_signal
tokenized_q = self.tokenizer(question, return_tensors="pt")
combined_qa = question + answer
tokenized_qa = self.tokenizer(combined_qa, padding="max_length", truncation=True,
max_length=VLM_MAX_LENGTH, return_tensors="pt")
input_ids = tokenized_qa.input_ids[0]
label = copy.deepcopy(input_ids)
len_of_q = len(tokenized_q.input_ids[0])
label[:len_of_q] = self.ignore_idx
len_of_pad = tokenized_qa.input_ids.eq(self.tokenizer.pad_token_id).sum().item()
label[-len_of_pad:] = self.ignore_idx
return input_ids, label
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
meta = self.data[idx]
image_id = meta['image']
image = Image.open(os.path.join(self.img_path, image_id)).convert('RGB')
image = self.transform(image)
conversation = meta['conversation']
input_id, label = self.preprocess(conversation)
return dict(image=image, input_ids=input_id, label=label)
# Function to calculate perplexity
def calculate_perplexity(logits, targets):
loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1), reduction='mean')
perplexity = torch.exp(loss).item()
return perplexity
# Main function to evaluate logits
def evaluate(logits_filename, json_file, img_path):
set_seed()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens(IMG_PATCH, special_tokens=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
testset = LLaVADataset(json_file, img_path, tokenizer, is_train=False)
test_llava_loader = DataLoader(testset, batch_size=VL_BATCH_SIZE, shuffle=False)
try:
# Ensure file name is valid
assert re.match(r"\d{8}\.npy", logits_filename), "File name must be an 8-digit student ID followed by '.npy'."
# Load logits
logits = np.load(logits_filename)
logits = torch.from_numpy(logits)
assert logits.shape == (len(test_llava_loader.dataset), VLM_MAX_LENGTH, 50257), f"Logits shape mismatch: expected ({len(test_llava_loader.dataset)}, {VLM_MAX_LENGTH}, 50257)."
targets = torch.cat([target['label'] for target in test_llava_loader]).cpu()
# Calculate perplexity
perplexity = calculate_perplexity(logits[:, :-1], targets[:, 1:])
except AssertionError as e:
perplexity = 1000
print(f"Evaluation failed: {e}")
print(f'{logits_filename[:-4]} - Perplexity: {round(perplexity)}')