Image MINV: modality abstraction and PLGMI image path cleanup#404
Conversation
 super().__init__(handler)
 logger.info("Image extension initialized.")

 def set_augment_strength(self, augment_strength: str) -> None:
Do all attacks still work with this function gone from the image extension? From what I can see, RaMIA.py, lines 75-76, will probably fail, since set_augment_strength() is removed but still called in RaMIA.
The refactor looks clean and great! But the removed functions: are they outdated, or do they exist somewhere else? Are we dropping them?
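While the method is being phased out, a caller like RaMIA could look it up defensively instead of calling it unconditionally. This is a hypothetical sketch, not code from the PR; `ImageExtension` here is a bare stand-in for the refactored class:

```python
class ImageExtension:
    """Stand-in for the refactored extension, which no longer defines set_augment_strength."""


def set_strength_if_supported(ext: object, strength: str) -> bool:
    """Call set_augment_strength only when the extension still provides it.

    Returns True if the setter existed and was called, False otherwise,
    instead of raising AttributeError on the new extension.
    """
    setter = getattr(ext, "set_augment_strength", None)
    if callable(setter):
        setter(strength)
        return True
    return False
```

With the new extension this simply returns False, so the RaMIA call site degrades gracefully rather than crashing.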
 ops_used.append(ops)
 out = torch.stack(ys, dim=0) if stack else ys
 return out, ops_used
 n_aug = n_aug // len(data)  # dummy to pass ruff
Is it preferable to write `def augmentation(self: Self, data: Tensor, n_aug: int) -> Tensor:  # noqa: ARG002`?
Not particularly important for the approval of the PR.
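For reference, ruff's ARG002 suppression is per-line, so the unused parameter can keep its name without the dummy statement. A minimal sketch with a plain-Python stand-in (not the project's actual class, and `list` standing in for `Tensor`):

```python
class AugmentFree:
    """Hypothetical modality extension whose augmentation is the identity."""

    def augmentation(self, data: list, n_aug: int) -> list:  # noqa: ARG002
        # n_aug is intentionally unused here; the noqa comment silences
        # ruff's ARG002 (unused-method-argument) check for this line only,
        # so no dummy assignment like `n_aug = n_aug // len(data)` is needed.
        return data
```

The noqa approach keeps the function body honest about what it actually does, which the dummy statement obscures.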
 logger.info("Training the GAN")
 # TODO: Change this input structure to just pass the attack class
 self.handler.train_gan(pseudo_loader = self.pseudo_loader,
checkpoint_interval is removed from the handler.train_gan() call, but it is still defined in the train_gan() function (celeba_plgmi_handler.py, line 96) and in audit.yaml. It is also still used inside train_gan(): `if i % checkpoint_interval == 0 and i > 0:`.
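If `checkpoint_interval` should stay inside `train_gan()` while being dropped from the call site, one option is to give it a keyword default so existing calls without it keep working. A toy sketch under assumed names (this is not the handler's real signature or training loop):

```python
def train_gan(pseudo_loader, n_iter: int = 5, checkpoint_interval: int = 2) -> list:
    """Toy loop: records the iterations where a checkpoint would be saved."""
    checkpoints = []
    for i in range(n_iter):
        # ... one GAN training step over pseudo_loader would go here ...
        if i % checkpoint_interval == 0 and i > 0:
            checkpoints.append(i)  # stand-in for saving gen/dis weights
    return checkpoints
```

Callers can then invoke `train_gan(pseudo_loader=...)` without mentioning the interval, while audit.yaml (or any caller that still cares) can pass it explicitly.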
| """Run the attack.""" | ||
| logger.info("Running the PLG-MI attack") | ||
| # Define image metrics class | ||
| # if getattr(self, "reference_model", None) is None: # noqa: ERA001 |
I don't really understand this comment. Is it a guide for a future implementation?
 Args:
 ----
 y (torch.tensor): The class labels.
 lr (float): The learning rate for optimization.
Update the arguments in the docstring to match the arguments of the function.
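For illustration, a docstring whose Args section lists exactly the parameters the function takes; the name `optimize_z` and its signature are hypothetical here, the real function's parameters may differ:

```python
def optimize_z(y, lr: float = 0.01) -> None:
    """Optimize the latent vector for the given class labels.

    Args:
    ----
        y (torch.tensor): The class labels.
        lr (float): The learning rate for optimization.

    """
```

Keeping only the parameters that actually exist in the signature is what ruff's pydocstyle/darglint-style checks would also enforce.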
 def max_margin_loss(out: torch.Tensor, iden: torch.Tensor) -> torch.Tensor:
 """Compute the max margin loss.
 # def max_margin_loss(out: torch.Tensor, iden: torch.Tensor) -> torch.Tensor:
Do we need the commented version of the max_margin_loss(...)?
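For context, a max-margin identity loss of the kind used in model-inversion attacks can be sketched in pure Python. The project's version operates on torch.Tensor batches; this list-based stand-in is an assumption about the general formula, not the PR's exact code:

```python
def max_margin_loss(out: list, iden: list) -> float:
    """Mean over the batch of (best wrong-class logit - target-class logit).

    out:  per-sample logit vectors, e.g. [[2.0, 1.0, 0.0], ...]
    iden: target class index for each sample
    Minimizing this pushes the target logit above every other logit.
    """
    losses = []
    for logits, target in zip(out, iden):
        wrong = max(v for j, v in enumerate(logits) if j != target)
        losses.append(wrong - logits[target])
    return sum(losses) / len(losses)
```

The loss is negative once the target class already wins by a margin, which is why it is minimized rather than driven to zero.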
 @@ -58,37 +58,38 @@
 "source": [
Line #28: `#y = torch.tensor([0, 1, 2, 3]).to(device) # fixed labels`. The duplication and the commented-out `y` are not needed.
henrikfo left a comment
I've added a few small comments and bits of feedback! Everything else looks great to me! I might have missed some context for them, so let me know if they are incorrect; otherwise, with these changes we can merge!
Description
Clean up the image-based PLGMI model inversion attack and introduce a modality abstraction layer. This PR contains only the reliable, tested image/CelebA MINV code — tabular extensions are deferred to a separate PR.
- `modality_extension.py`: new abstract base class for modality-specific extensions
- `image_extension.py`: reduced from 194 to ~35 lines using the new abstraction
- `plgmi.py`: restricted to the image-only path; removed all tabular/dataframe branches, CTGAN loading, and leftover WIP/debug methods (`optimize_z_grad_original`, `optimize_z_grad2`, `load_reference_model`, etc.)
- `celebA_plgmi_handler.py`: removed the checkpoint files (`gen_3000.pth`/`dis_3000.pth`)
- `audit.yaml`: restored the `data_modality` field (was commented out)
- `gan_handler.py`/`gan_losses.py`: `max_margin_loss` now lives in `gan_losses.py`
- `schemas.py`: added `pytorch_tabular` to the allowed `model_type` values

Resolved Issues
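The modality abstraction described in the PR can be pictured as an abstract base class that each modality extension implements. This is a hypothetical sketch of the shape, not the actual contents of `modality_extension.py`:

```python
from abc import ABC, abstractmethod


class ModalityExtension(ABC):
    """Base class for modality-specific extensions (image, tabular, ...)."""

    def __init__(self, handler) -> None:
        self.handler = handler

    @abstractmethod
    def augmentation(self, data, n_aug: int):
        """Return n_aug augmented views of data; modality-specific."""


class ImageExtension(ModalityExtension):
    """Minimal image modality: identity augmentation as a placeholder."""

    def augmentation(self, data, n_aug: int):
        return [data for _ in range(n_aug)]
```

With the shared `__init__` and the abstract hook pulled into the base class, each concrete extension only needs to supply its modality-specific pieces, which is how a 194-line file can shrink to ~35.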
How Has This Been Tested?
Manually tested end-to-end on the CelebA example (`examples/minv/celebA/`) with a trained ResNet152 target model.

Related Pull Requests