66from scipy .stats import norm
77from sklearn import metrics
88
9- from src .config import get_cur_dir , set_seed
9+ from src .config import get_cur_dir , safe_assert , set_seed
1010from src .defense import noisy_bn
1111
1212
@@ -39,21 +39,14 @@ def mia_vs_bn(exp, config) -> dict:
3939 f"{ cur_dir } /{ config ['bns_path' ]} /pool/bn_{ exp } _sample{ sample } .bif"
4040 )
4141
42- bn_theta_hat_ie = gum .LazyPropagation (bn_theta_hat )
43- bn_theta_ie = gum .LazyPropagation (bn_theta )
44-
45- # ... retrieve rpop, ...
46- rpop = gpop [gpop [f"in-rpop-{ sample } " ]].iloc [:, : len (bn_theta .nodes ())]
47-
4842 # try:
4943
5044 # ... and perform membership inference on gpop
5145 power_vec , auc = run_mia (
52- bn_theta_hat_ie ,
53- bn_theta_ie ,
54- rpop ,
46+ bn_theta_hat ,
47+ bn_theta ,
5548 gpop ,
56- gpop [ f"in-pool- { sample } " ] ,
49+ sample ,
5750 eval (config ["error" ]),
5851 )
5952 power_res [f"power_BN_sample{ sample } " ] = power_vec
@@ -108,21 +101,14 @@ def mia_vs_cn(exp, config) -> pd.DataFrame:
108101 f"{ cur_dir } /{ config ['bns_path' ]} /rpop/bn_{ exp } _sample{ sample } .bif"
109102 )
110103
111- bn_theta_hat_ie = gum .LazyPropagation (bn_theta_hat )
112- bn_theta_ie = gum .LazyPropagation (bn_theta )
113-
114- # ... retrieve rpop, ...
115- rpop = gpop [gpop [f"in-rpop-{ sample } " ]].iloc [:, : len (bn_theta .nodes ())]
116-
117104 # try:
118105
119106 # ... and perform membership inference on gpop
120107 power_vec , auc = run_mia (
121- bn_theta_hat_ie ,
122- bn_theta_ie ,
123- rpop ,
108+ bn_theta_hat ,
109+ bn_theta ,
124110 gpop ,
125- gpop [ f"in-pool- { sample } " ] ,
111+ sample ,
126112 eval (config ["error" ]),
127113 )
128114 power_res [f"power_CN_sample{ sample } " ] = power_vec
@@ -209,9 +195,6 @@ def find_epsilon(exp, config) -> dict:
209195 f"{ cur_dir } /{ config ['bns_path' ]} /pool/bn_{ exp } _sample{ sample } .bif"
210196 )
211197
212- # ... retrieve rpop, ...
213- rpop = gpop [gpop [f"in-rpop-{ sample } " ]].iloc [:, : len (bn_theta .nodes ())]
214-
215198 # ... get CN AUC, ...
216199 auc_cn = auc_res .loc [auc_res ["sample" ] == sample , "auc_cn" ].values [0 ]
217200
@@ -226,16 +209,12 @@ def find_epsilon(exp, config) -> dict:
226209 scale = (2 * bn_theta_hat .size ()) / (pool_ss * eps )
227210 bn_noisy = noisy_bn (bn_theta_hat , scale )
228211
229- bn_noisy_ie = gum .LazyPropagation (bn_noisy )
230- bn_theta_ie = gum .LazyPropagation (bn_theta )
231-
232212 # Perform membership inference on gpop
233213 power_vec , auc = run_mia (
234- bn_noisy_ie ,
235- bn_theta_ie ,
236- rpop ,
214+ bn_noisy ,
215+ bn_theta ,
237216 gpop ,
238- gpop [ f"in-pool- { sample } " ] ,
217+ sample ,
239218 eval (config ["error" ]),
240219 )
241220
@@ -261,39 +240,53 @@ def find_epsilon(exp, config) -> dict:
261240
262241
263242# MIA: membership inference attack
# MIA: membership inference attack
def run_mia(model, baseline, gpop, sample, error_vec):
    """Run a membership inference attack of `model` against `baseline`.

    The general population `gpop` is split, using its bookkeeping columns,
    into a reference population (rows flagged `in-rpop-{sample}`) and an
    evaluation set (the complement). Log-likelihood ratios are computed on
    both, and the attack power is evaluated at every error level in
    `error_vec`.

    Returns a tuple `(power_vec, auc)`: the list of powers (one per error
    level) and the AUC of the power-vs-error curve.

    NOTE(review): assumes `gpop`'s leading columns are exactly the BN
    variables, followed by the `in-rpop-*` / `in-pool-*` flags — confirm
    against the callers.
    """

    # Inference engines for likelihood queries on both networks
    lazy_model = gum.LazyPropagation(model)
    lazy_baseline = gum.LazyPropagation(baseline)

    # Split gpop into the reference population and its complement (the
    # evaluation set), keeping only the BN variable columns; the ground
    # truth is the pool-membership flag restricted to the evaluation set.
    in_reference = gpop[f"in-rpop-{sample}"]
    n_vars = len(model.nodes())
    reference_pop = gpop[in_reference].iloc[:, :n_vars]
    eval_pop = gpop[~in_reference].iloc[:, :n_vars]
    ground_truth = gpop[~in_reference].loc[:, f"in-pool-{sample}"]

    # llr(x) on the reference population (NaNs dropped, sorted — get_power
    # takes quantiles of this series) and on the evaluation set.
    llr_reference = (
        reference_pop.apply(
            lambda row: get_llr(row.to_dict(), lazy_baseline, lazy_model), axis=1
        )
        .dropna()
        .sort_values()
    )
    llr_eval = eval_pop.apply(
        lambda row: get_llr(row.to_dict(), lazy_baseline, lazy_model), axis=1
    )

    # Attack power at every requested error level, then the resulting AUC.
    power_vec = [
        get_power(llr_reference, llr_eval, ground_truth, error)
        for error in error_vec
    ]
    auc = metrics.auc(error_vec, power_vec)

    # Debug
    safe_assert(len(reference_pop) + len(eval_pop) == len(gpop))
    safe_assert(len(eval_pop) == len(ground_truth))
    safe_assert(len(power_vec) == len(error_vec))

    return power_vec, auc
287280
288281
289282# Get the attack power related to a fixed error
290- def get_power (llr_ref , llr_gen , ground_truth , error ) -> float :
283+ def get_power (llr_ref , llr_eval , ground_truth , error ) -> float :
291284
292285 # Compute the threshold
293286 t = np .quantile (llr_ref , 1 - error ).item ()
294287
295288 # Test: L(x) > t => reject H_0 => assign `x` to target_pop
296- y_pred = llr_gen > t
289+ y_pred = llr_eval > t
297290
298291 # Compute power (i.e., true positive rate)
299292 power = sum (ground_truth & y_pred ) / sum (ground_truth )
0 commit comments