-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_arch.py
More file actions
174 lines (143 loc) · 7.04 KB
/
model_arch.py
File metadata and controls
174 lines (143 loc) · 7.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import torch
from torchvision.models import resnet50
from typing import Optional
class KNNClassifier(torch.nn.Module):
    """A simple K-NN classifier that acts as a classification head.

    The module stores a bank of feature vectors with their labels and
    classifies new feature vectors by majority vote among the ``k`` most
    similar bank entries. It has no learnable parameters.

    The bank is held in two registered buffers (``features`` and ``labels``),
    so it follows ``module.to(...)`` and is included in ``state_dict()``.
    All computation happens on whatever device the buffers currently live on.

    API:
        - add_samples(features, labels): add feature vectors to the bank.
        - reset(): clears the bank.
        - predict(features): returns predicted class indices.
        - predict_proba(features): returns class probability distributions.

    Args:
        feature_dim: dimensionality of input feature vectors
        num_classes: number of target classes
        k: number of neighbors to use
        distance: 'cosine' or 'l2'
        device: device on which the (initially empty) bank is created;
            defaults to CPU

    Raises:
        ValueError: if ``distance`` is not 'cosine' or 'l2'.
    """

    def __init__(self, feature_dim: int, num_classes: int, k: int = 5, distance: str = "cosine", device: Optional[torch.device] = None):
        super().__init__()
        if distance not in ("cosine", "l2"):
            raise ValueError("distance must be 'cosine' or 'l2'")
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.k = k
        self.distance = distance
        initial_device = torch.device("cpu") if device is None else torch.device(device)
        # Buffers (not plain attributes) so the bank moves with model.to(device).
        self.register_buffer("features", torch.empty((0, feature_dim), dtype=torch.float32, device=initial_device))
        self.register_buffer("labels", torch.empty((0,), dtype=torch.long, device=initial_device))

    @property
    def device(self) -> torch.device:
        """Device the feature bank currently lives on.

        Derived from the buffer instead of being cached, so it stays correct
        after ``.to()``, ``.cuda()`` or ``.cpu()`` on this module or a parent.
        """
        return self.features.device

    @device.setter
    def device(self, value) -> None:
        # Assigning ``knn.device = ...`` moves the bank to that device.
        self.features = self.features.to(value)
        self.labels = self.labels.to(value)

    def reset(self):
        """Clear the feature bank, keeping it on its current device."""
        bank_device = self.features.device
        self.features = torch.empty((0, self.feature_dim), dtype=torch.float32, device=bank_device)
        self.labels = torch.empty((0,), dtype=torch.long, device=bank_device)

    @torch.no_grad()
    def add_samples(self, new_features: torch.Tensor, new_labels: torch.Tensor):
        """Add new feature vectors and corresponding labels to the bank.

        Args:
            new_features: (N, D) tensor, or (D,) for a single sample
            new_labels: (N,) tensor, or a 0-dim tensor for a single sample

        Raises:
            ValueError: if the feature dimension does not match
                ``feature_dim``, if the number of labels differs from the
                number of feature vectors, or if a label lies outside
                ``[0, num_classes)``.
        """
        bank_device = self.features.device
        new_features = new_features.to(bank_device).float()
        new_labels = new_labels.to(bank_device).long()
        if new_features.ndim == 1:
            new_features = new_features.unsqueeze(0)
        if new_labels.ndim == 0:
            new_labels = new_labels.unsqueeze(0)
        if new_features.ndim != 2:
            raise ValueError(f"new_features must have shape (N, D), got {tuple(new_features.shape)}")
        if new_features.size(1) != self.feature_dim:
            raise ValueError(f"feature dimension mismatch: expected {self.feature_dim}, got {new_features.size(1)}")
        if new_labels.ndim != 1 or new_labels.size(0) != new_features.size(0):
            raise ValueError(
                f"labels/features mismatch: got {new_features.size(0)} feature vectors "
                f"and labels of shape {tuple(new_labels.shape)}"
            )
        # Fail here rather than with an opaque scatter index error at predict time.
        if new_labels.numel() > 0 and (int(new_labels.min()) < 0 or int(new_labels.max()) >= self.num_classes):
            raise ValueError(f"labels must be in [0, {self.num_classes})")
        # Always concatenate (even into an empty bank) so the bank owns its
        # storage and never aliases the caller's tensors.
        self.features = torch.cat((self.features, new_features), dim=0)
        self.labels = torch.cat((self.labels, new_labels), dim=0)

    @torch.no_grad()
    def _pairwise_similarity(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (B, N) similarity matrix between x and the bank features.

        If distance == 'cosine', returns cosine similarity (higher is better).
        If distance == 'l2', returns the negative *squared* L2 distance
        (higher is better); the neighbor ranking is the same as for the true
        L2 distance.

        Raises:
            ValueError: if the feature bank is empty.
        """
        bank = self.features
        if bank.size(0) == 0:
            raise ValueError("Feature bank is empty. Add samples before calling predict().")
        x = x.to(bank.device).float()
        if self.distance == "cosine":
            x_unit = torch.nn.functional.normalize(x, dim=1)
            bank_unit = torch.nn.functional.normalize(bank, dim=1)
            # (B, D) @ (D, N) -> (B, N)
            return x_unit @ bank_unit.t()
        # ||x - b||^2 = ||x||^2 + ||b||^2 - 2 x.b
        x_sq = (x ** 2).sum(dim=1, keepdim=True)  # (B, 1)
        bank_sq = (bank ** 2).sum(dim=1).unsqueeze(0)  # (1, N)
        cross = 2.0 * (x @ bank.t())  # (B, N)
        # Rounding can push the expansion slightly below zero; clamp it.
        sq_dist = (x_sq + bank_sq - cross).clamp_min(0.0)
        return -sq_dist

    @torch.no_grad()
    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for input features.

        Args:
            x: (B, D) feature vectors (last layer of encoder), or (D,) for a
                single sample

        Returns:
            (B, num_classes) tensor; each row is the fraction of the k nearest
            neighbors voting for each class. If the bank holds fewer than k
            samples, all of them vote.

        Raises:
            ValueError: if the feature bank is empty.
        """
        if x.ndim == 1:
            x = x.unsqueeze(0)
        sims = self._pairwise_similarity(x)  # (B, N_bank)
        k = min(self.k, sims.size(1))
        neighbor_idx = torch.topk(sims, k=k, dim=1).indices  # (B, k)
        neighbor_labels = self.labels[neighbor_idx]  # (B, k)
        # Accumulate one vote per neighbor straight into the (B, C) table.
        counts = torch.zeros((neighbor_labels.size(0), self.num_classes), device=sims.device)
        votes = torch.ones_like(neighbor_labels, dtype=counts.dtype)
        counts.scatter_add_(1, neighbor_labels, votes)
        return counts / counts.sum(dim=1, keepdim=True).clamp_min(1.0)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return predicted class indices (shape (B,)) for input features."""
        return torch.argmax(self.predict_proba(x), dim=1)
def get_resnet50_model(pretrained: bool = False, output_dim: int = 1000, attach_knn: bool = True, knn_k: int = 5, knn_distance: str = "cosine"):
    """Return a ResNet-50 encoder whose forward pass yields last-layer features.

    The final `fc` layer is always replaced with an Identity, so the model
    returns the feature vector that would normally feed a classification head.
    If attach_knn=True a `knn_classifier` attribute (KNNClassifier) is attached
    to the returned model.

    Args:
        pretrained: load ImageNet weights if True
        output_dim: number of classes (used only to size the KNN head)
        attach_knn: whether to attach a KNNClassifier instance as `model.knn_classifier`
        knn_k: number of neighbors for the KNN head
        knn_distance: 'cosine' or 'l2'

    Returns:
        model: torchvision ResNet model whose `fc` is an Identity (so forward returns features)
    """
    try:
        # torchvision >= 0.13 replaces the deprecated `pretrained` flag with `weights`.
        from torchvision.models import ResNet50_Weights
    except ImportError:
        # Older torchvision: only the legacy boolean flag is available.
        model = resnet50(pretrained=pretrained)
    else:
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` maps to.
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = resnet50(weights=weights)

    # Read the feature width before the head is removed.
    feature_dim = model.fc.in_features
    # replace final fc with identity so forward() returns feature vectors
    model.fc = torch.nn.Identity()

    if attach_knn:
        model.knn_classifier = KNNClassifier(
            feature_dim=feature_dim,
            num_classes=output_dim,
            k=knn_k,
            distance=knn_distance,
            device=torch.device("cpu"),
        )
    return model
if __name__ == "__main__":
    # Quick demonstration of how to use the encoder + KNN head
    model = get_resnet50_model(pretrained=False, output_dim=10)
    # Inference only: switch to eval mode so BatchNorm uses its running
    # statistics instead of the statistics of this tiny demo batch (and does
    # not update them as a side effect).
    model.eval()
    print("ResNet encoder with KNN head:")
    print(model)

    # Create dummy data and show feature extraction -> KNN prediction flow
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        feats = model(dummy)  # (2, feature_dim)

    # Create a small feature bank of 6 samples for 3 classes
    bank_feats = torch.randn(6, feats.size(1))
    bank_labels = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.long)
    model.knn_classifier.add_samples(bank_feats, bank_labels)

    preds = model.knn_classifier.predict(feats)
    probs = model.knn_classifier.predict_proba(feats)
    print("preds:", preds)
    print("probs:", probs)