-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
74 lines (63 loc) · 2.29 KB
/
server.py
File metadata and controls
74 lines (63 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import openfhe
from client import ClientParameters
from enroller import EnrollerDatabase
class Server:
    """Evaluate encrypted similarity scores and threshold tests for a query.

    Every method works on OpenFHE ciphertexts through ``params.cc`` (or the
    ``cc`` argument) using only homomorphic evaluation, relinearization and
    rescaling calls; nothing in this class decrypts.
    """

    # Degree of the Chebyshev approximation of the +/-1 step function used in
    # ``_compute_threshold``.
    _CHEBYSHEV_DEGREE = 59

    # Coefficients, constant term first (the order ``EvalPoly`` expects), of
    #   f(x) = (315x - 420x^3 + 378x^5 - 180x^7 + 35x^9) / 128,
    # i.e. the f_4 composite sign polynomial of Cheon et al. f(+/-1) = +/-1 and
    # f'(+/-1) = 0, so values already close to +/-1 are pulled closer to +/-1.
    _SIGN_SHARPEN_COEFFS = (
        0, 315 / 128, 0, -420 / 128, 0, 378 / 128, 0, -180 / 128, 0, 35 / 128,
    )

    def compute_identities(
        self,
        params: ClientParameters,
        database: EnrollerDatabase,
        query: openfhe.Ciphertext,
    ) -> list[openfhe.Ciphertext]:
        """Run the encrypted threshold test of ``query`` against every batch.

        Args:
            params: Supplies ``cc``, ``embedding_dim`` and
                ``similarity_threshold``.
            database: Its ``database`` attribute is iterated; each element is
                one batch of ciphertexts, indexed ``0..embedding_dim-1``.
            query: Encrypted query embedding.

        Returns:
            One threshold ciphertext per batch, in ``database.database`` order
            (empty if the database has no batches).
        """
        cc = params.cc
        return [
            self._compute_threshold(
                cc,
                params.similarity_threshold,
                self._compute_score(cc, params.embedding_dim, query, batch),
            )
            for batch in database.database
        ]

    def compute_membership(
        self,
        params: ClientParameters,
        database: EnrollerDatabase,
        query: openfhe.Ciphertext,
    ) -> openfhe.Ciphertext:
        """Aggregate the per-batch threshold results into a single ciphertext.

        Each per-batch result from ``compute_identities`` is reduced with
        ``EvalSum`` over ``params.batch_size`` slots, and those totals are
        added together.

        Args:
            params: Supplies ``cc``, ``batch_size`` and the fields used by
                ``compute_identities``.
            database: Enrolled database; must contain at least one batch.
            query: Encrypted query embedding.

        Returns:
            Ciphertext holding the summed threshold indicators. With the
            {0, 2} encoding produced by ``_compute_threshold`` this is
            presumably about twice the number of matches — TODO confirm
            against the client-side decoding.

        Raises:
            ValueError: If ``database.database`` contains no batches.
        """
        cc = params.cc
        indicators = self.compute_identities(params, database, query)
        if not indicators:
            raise ValueError(
                "database contains no batches; membership cannot be computed"
            )
        batch_totals = [cc.EvalSum(ct, params.batch_size) for ct in indicators]
        result = batch_totals[0]
        for total in batch_totals[1:]:
            cc.EvalAddInPlace(result, total)
        return result

    def _compute_score(
        self,
        cc: openfhe.CryptoContext,
        embedding_dim: int,
        query: openfhe.Ciphertext,
        batch: list[openfhe.Ciphertext],
    ) -> openfhe.Ciphertext:
        """Compute the encrypted similarity scores of ``query`` against ``batch``.

        For each ``i`` in ``range(embedding_dim)`` the query is rotated by
        ``i`` slots and multiplied by ``batch[i]``. The products are summed,
        then relinearized and rescaled once.

        Args:
            cc: Crypto context used for evaluation.
            embedding_dim: Number of rotation/multiply terms to sum.
            query: Encrypted query embedding.
            batch: Ciphertexts indexed ``0..embedding_dim-1``.

        Returns:
            Relinearized, rescaled ciphertext of the summed products.

        Raises:
            ValueError: If ``embedding_dim`` < 1 (there would be nothing to sum).
            IndexError: If ``batch`` has fewer than ``embedding_dim`` entries.
        """
        if embedding_dim < 1:
            raise ValueError(f"embedding_dim must be >= 1, got {embedding_dim}")
        cyclotomic_order = cc.GetCyclotomicOrder()
        # Hoisted rotations: one precomputation on the query serves every
        # rotation index below.
        precomp = cc.EvalFastRotationPrecompute(query)

        def rotated_product(i):
            rotated = cc.EvalFastRotation(query, i, cyclotomic_order, precomp)
            # Relinearization is deferred to a single call on the final sum.
            return cc.EvalMultNoRelin(batch[i], rotated)

        # Accumulate as we go rather than holding every product in memory.
        total = rotated_product(0)
        for i in range(1, embedding_dim):
            cc.EvalAddInPlace(total, rotated_product(i))
        cc.RelinearizeInPlace(total)
        cc.RescaleInPlace(total)
        return total

    def _compute_threshold(
        self,
        cc: openfhe.CryptoContext,
        threshold: float,
        score: openfhe.Ciphertext,
    ) -> openfhe.Ciphertext:
        """Homomorphically test ``score >= threshold`` slot by slot.

        Args:
            cc: Crypto context used for evaluation.
            threshold: Plaintext cut-off compared against each score slot.
            score: Encrypted scores. The Chebyshev fit is over [-1, 1], so
                scores are assumed to lie in that interval (presumably cosine
                similarities of normalized embeddings — TODO confirm).

        Returns:
            Ciphertext whose slots are approximately 2 where
            ``score >= threshold`` and approximately 0 otherwise (a +/-1 sign
            value shifted by +1).
            NOTE(review): if a {0, 1} indicator is intended, this still needs
            halving — confirm with the consumer of this value.
        """
        # Polynomial approximation of the step function: +1 at or above the
        # threshold, -1 below it.
        approx_sign = cc.EvalChebyshevFunction(
            lambda value: 1.0 if float(value) >= threshold else -1.0,
            score,
            a=-1.0,
            b=1.0,
            degree=self._CHEBYSHEV_DEGREE,
        )
        # A finite-degree polynomial cannot reproduce the jump exactly; the
        # odd sign polynomial pushes near-+/-1 outputs closer to +/-1.
        sharpened = cc.EvalPoly(approx_sign, list(self._SIGN_SHARPEN_COEFFS))
        # Map {-1, +1} to {0, 2}.
        return cc.EvalAdd(sharpened, 1)