Skip to content

Commit c47f659

Browse files
Update: evaluate MIA on the complement of rpop, instead of gpop
1 parent 3f42746 commit c47f659

1 file changed

Lines changed: 33 additions & 40 deletions

File tree

src/mia.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.stats import norm
77
from sklearn import metrics
88

9-
from src.config import get_cur_dir, set_seed
9+
from src.config import get_cur_dir, safe_assert, set_seed
1010
from src.defense import noisy_bn
1111

1212

@@ -39,21 +39,14 @@ def mia_vs_bn(exp, config) -> dict:
3939
f"{cur_dir}/{config['bns_path']}/pool/bn_{exp}_sample{sample}.bif"
4040
)
4141

42-
bn_theta_hat_ie = gum.LazyPropagation(bn_theta_hat)
43-
bn_theta_ie = gum.LazyPropagation(bn_theta)
44-
45-
# ... retrieve rpop, ...
46-
rpop = gpop[gpop[f"in-rpop-{sample}"]].iloc[:, : len(bn_theta.nodes())]
47-
4842
# try:
4943

5044
# ... and perform membership inference on the complement of rpop in gpop
5145
power_vec, auc = run_mia(
52-
bn_theta_hat_ie,
53-
bn_theta_ie,
54-
rpop,
46+
bn_theta_hat,
47+
bn_theta,
5548
gpop,
56-
gpop[f"in-pool-{sample}"],
49+
sample,
5750
eval(config["error"]),
5851
)
5952
power_res[f"power_BN_sample{sample}"] = power_vec
@@ -108,21 +101,14 @@ def mia_vs_cn(exp, config) -> pd.DataFrame:
108101
f"{cur_dir}/{config['bns_path']}/rpop/bn_{exp}_sample{sample}.bif"
109102
)
110103

111-
bn_theta_hat_ie = gum.LazyPropagation(bn_theta_hat)
112-
bn_theta_ie = gum.LazyPropagation(bn_theta)
113-
114-
# ... retrieve rpop, ...
115-
rpop = gpop[gpop[f"in-rpop-{sample}"]].iloc[:, : len(bn_theta.nodes())]
116-
117104
# try:
118105

119106
# ... and perform membership inference on the complement of rpop in gpop
120107
power_vec, auc = run_mia(
121-
bn_theta_hat_ie,
122-
bn_theta_ie,
123-
rpop,
108+
bn_theta_hat,
109+
bn_theta,
124110
gpop,
125-
gpop[f"in-pool-{sample}"],
111+
sample,
126112
eval(config["error"]),
127113
)
128114
power_res[f"power_CN_sample{sample}"] = power_vec
@@ -209,9 +195,6 @@ def find_epsilon(exp, config) -> dict:
209195
f"{cur_dir}/{config['bns_path']}/pool/bn_{exp}_sample{sample}.bif"
210196
)
211197

212-
# ... retrieve rpop, ...
213-
rpop = gpop[gpop[f"in-rpop-{sample}"]].iloc[:, : len(bn_theta.nodes())]
214-
215198
# ... get CN AUC, ...
216199
auc_cn = auc_res.loc[auc_res["sample"] == sample, "auc_cn"].values[0]
217200

@@ -226,16 +209,12 @@ def find_epsilon(exp, config) -> dict:
226209
scale = (2 * bn_theta_hat.size()) / (pool_ss * eps)
227210
bn_noisy = noisy_bn(bn_theta_hat, scale)
228211

229-
bn_noisy_ie = gum.LazyPropagation(bn_noisy)
230-
bn_theta_ie = gum.LazyPropagation(bn_theta)
231-
232212
# Perform membership inference on the complement of rpop in gpop
233213
power_vec, auc = run_mia(
234-
bn_noisy_ie,
235-
bn_theta_ie,
236-
rpop,
214+
bn_noisy,
215+
bn_theta,
237216
gpop,
238-
gpop[f"in-pool-{sample}"],
217+
sample,
239218
eval(config["error"]),
240219
)
241220

@@ -261,39 +240,53 @@ def find_epsilon(exp, config) -> dict:
261240

262241

263242
# MIA: membership inference attack
def run_mia(model, baseline, gpop, sample, error_vec):
    """Run the membership inference attack for one sample.

    The rows of `gpop` flagged by its `in-rpop-{sample}` column form the
    reference population (rpop); the remaining rows form the evaluation set.
    A log-likelihood ratio is computed for every row of both sets via
    `get_llr`, and `get_power` is then called once per entry of `error_vec`.

    Args:
        model: network wrapped in `gum.LazyPropagation` and passed to
            `get_llr` as its third argument.
        baseline: network wrapped in `gum.LazyPropagation` and passed to
            `get_llr` as its second argument.
        gpop: DataFrame with the `in-rpop-{sample}` and `in-pool-{sample}`
            columns. NOTE(review): assumes its first `len(model.nodes())`
            columns are the network variables -- confirm against the caller.
        sample: value interpolated into the two column names above.
        error_vec: sequence of error levels, also used as the x-axis of
            the AUC.

    Returns:
        A `(power_vec, auc)` tuple: the list of powers (one per error level)
        and the value of `metrics.auc(error_vec, power_vec)`.
    """

    # Create objects for inference
    model_ie = gum.LazyPropagation(model)
    baseline_ie = gum.LazyPropagation(baseline)

    def row_llr(row):
        # llr(x) of a single record against the two inference engines
        return get_llr(row.to_dict(), baseline_ie, model_ie)

    # Split gpop into rpop and its complement (the evaluation set)
    n_vars = len(model.nodes())
    in_rpop = gpop[f"in-rpop-{sample}"]
    held_out = gpop[~in_rpop]

    rpop = gpop[in_rpop].iloc[:, :n_vars]
    eval_pop = held_out.iloc[:, :n_vars]
    ground_truth = held_out.loc[:, f"in-pool-{sample}"]

    # Compute llr(x)'s; only the reference scores are cleaned and sorted
    llr_rpop = rpop.apply(row_llr, axis=1).dropna().sort_values()
    llr_eval = eval_pop.apply(row_llr, axis=1)

    # Get the power for each error
    power_vec = [
        get_power(llr_rpop, llr_eval, ground_truth, error)
        for error in error_vec
    ]

    # Compute and store AUC
    auc = metrics.auc(error_vec, power_vec)

    # Debug
    safe_assert(len(rpop) + len(eval_pop) == len(gpop))
    safe_assert(len(eval_pop) == len(ground_truth))
    safe_assert(len(power_vec) == len(error_vec))

    return power_vec, auc
287280

288281

289282
# Get the attack power related to a fixed error
290-
def get_power(llr_ref, llr_gen, ground_truth, error) -> float:
283+
def get_power(llr_ref, llr_eval, ground_truth, error) -> float:
291284

292285
# Compute the threshold
293286
t = np.quantile(llr_ref, 1 - error).item()
294287

295288
# Test: L(x) > t => reject H_0 => assign `x` to target_pop
296-
y_pred = llr_gen > t
289+
y_pred = llr_eval > t
297290

298291
# Compute power (i.e., true positive rate)
299292
power = sum(ground_truth & y_pred) / sum(ground_truth)

0 commit comments

Comments
 (0)