
Image MINV: modality abstraction and PLGMI image path cleanup#404

Open
fazelehh wants to merge 3 commits into main from minv-image-cleanups

Conversation


fazelehh (Collaborator) commented Apr 1, 2026

Description

Clean up the image-based PLGMI model inversion attack and introduce a modality abstraction layer. This PR contains only the reliable, tested image/CelebA MINV code — tabular extensions are deferred to a separate PR.

  • Add modality_extension.py abstract base class for modality-specific extensions
  • Refactor image_extension.py from 194 to ~35 lines using the new abstraction
  • Clean plgmi.py to image-only path: remove all tabular/dataframe branches, CTGAN loading, and leftover WIP/debug methods (optimize_z_grad_original, optimize_z_grad2, load_reference_model, etc.)
  • Remove hardcoded checkpoint paths (gen_3000.pth/dis_3000.pth) from celebA_plgmi_handler.py
  • Fix audit.yaml data_modality field (was commented out)
  • Add optional discriminator support in gan_handler.py
  • Refactor max_margin_loss in gan_losses.py
  • Add pytorch_tabular to allowed model_type values in schemas.py
  • Apply ruff linting fixes across all changed files
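The modality abstraction described above could look roughly like this (a minimal sketch; the actual class and method names in modality_extension.py and image_extension.py may differ):

```python
from abc import ABC, abstractmethod


class ModalityExtension(ABC):
    """Base class for modality-specific MINV extensions (names are illustrative)."""

    def __init__(self, handler):
        self.handler = handler

    @abstractmethod
    def transform(self, data):
        """Apply modality-specific preprocessing to a batch."""


class ImageExtension(ModalityExtension):
    """Image modality: only the image-specific logic remains after the refactor."""

    def transform(self, data):
        # Illustrative step: scale 8-bit pixel values into [0, 1]
        return [x / 255.0 for x in data]
```

With a shared base class like this, a future tabular extension only has to implement its own transform, which is what lets image_extension.py shrink to ~35 lines.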

Resolved Issues

  • fixes #Issue

How Has This Been Tested?

Manually tested end-to-end on the CelebA example (examples/minv/celebA/) with a trained ResNet152 target model.

Related Pull Requests


super().__init__(handler)
logger.info("Image extension initialized.")

def set_augment_strength(self, augment_strength: str) -> None:

Do all attacks still work with this function gone from the image extension? From what I can see, RaMIA.py, lines 75-76, will probably fail since "set_augment_strength()" is removed but still called in ramia.
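Until the API settles, a defensive pattern on the caller's side could avoid the crash (hypothetical helper; the real RaMIA call site may differ):

```python
def get_augment_strength_setter(extension):
    """Return the extension's set_augment_strength if present, else a no-op.

    Hypothetical guard: lets callers survive extensions that dropped the method.
    """
    setter = getattr(extension, "set_augment_strength", None)
    if callable(setter):
        return setter
    return lambda strength: None  # no-op fallback
```

This keeps RaMIA running against both the old and the refactored image extension, at the cost of silently ignoring the strength when the method is gone.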


The refactor looks clean and great! But the functions that were removed, are they outdated or do they exist somewhere else? Are we dropping them?

ops_used.append(ops)
out = torch.stack(ys, dim=0) if stack else ys
return out, ops_used
n_aug = n_aug // len(data) #dummy to pass ruff

Would it be preferable to use "def augmentation(self: Self, data: Tensor, n_aug: int) -> Tensor:  # noqa: ARG002" instead?
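That is, replacing the dummy statement with a signature-level suppression; a sketch of both variants (standalone functions here for illustration, since the real method lives on the image extension class):

```python
# Variant currently in the PR: a dummy statement keeps ruff's unused-argument
# rule quiet at the cost of a meaningless line.
def augmentation_with_dummy(data, n_aug):
    n_aug = n_aug // len(data)  # dummy to pass ruff
    return list(data)


# Reviewer's suggestion: suppress the rule in the signature instead.
# ARG002 covers unused method arguments; ARG001 would apply to a plain function.
def augmentation_with_noqa(data, n_aug):  # noqa: ARG001
    return list(data)
```

Both behave identically; the noqa form documents the intent (argument kept for interface compatibility) rather than hiding it behind dead arithmetic.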


Not particularly important for the approval of the PR.


logger.info("Training the GAN")
# TODO: Change this input structure to just pass the attack class
self.handler.train_gan(pseudo_loader = self.pseudo_loader,

checkpoint_interval is removed from the handler.train_gan() call, but it is still defined in the train_gan() function (celeba_plgmi_handler.py, line 96) and in audit.yaml. It is also still used inside train_gan(): if i % checkpoint_interval == 0 and i > 0:
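In other words, the parameter should either be removed everywhere or kept consistently optional, so the guard only fires when a caller actually passes it. A sketch of the second option (function shape assumed from the comment):

```python
def train_gan(pseudo_loader, n_iter, checkpoint_interval=None, save_checkpoint=print):
    """Hypothetical sketch: checkpoint_interval stays optional and the
    checkpoint guard is skipped entirely when it is not provided."""
    for i in range(n_iter):
        # ... one GAN training step on pseudo_loader would go here ...
        if checkpoint_interval and i % checkpoint_interval == 0 and i > 0:
            save_checkpoint(i)  # persist generator/discriminator state
```

With a default of None, callers that dropped the argument keep working, and the audit.yaml entry can stay as an opt-in setting.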

"""Run the attack."""
logger.info("Running the PLG-MI attack")
# Define image metrics class
# if getattr(self, "reference_model", None) is None: # noqa: ERA001

I don't really understand these comments; are they a guide for a future implementation?

Args:
----
y (torch.tensor): The class labels.
lr (float): The learning rate for optimization.

Update the arguments in the docstring to match the function's actual signature.


def max_margin_loss(out: torch.Tensor, iden: torch.Tensor) -> torch.Tensor:
"""Compute the max margin loss.
# def max_margin_loss(out: torch.Tensor, iden: torch.Tensor) -> torch.Tensor:

Do we need the commented-out version of max_margin_loss(...)?
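For context, the arithmetic of a max-margin identity loss as used in PLG-MI-style inversion, shown on plain Python lists for clarity (a sketch of the general technique, not necessarily the exact gan_losses.py implementation):

```python
def max_margin_loss(out, iden):
    """Mean of (best competing logit - target logit) over the batch.

    Minimizing this drives each sample's target-identity logit above all
    other logits; the value goes negative once the target wins the margin.
    """
    margins = []
    for logits, target_idx in zip(out, iden):
        target = logits[target_idx]
        runner_up = max(v for j, v in enumerate(logits) if j != target_idx)
        margins.append(runner_up - target)
    return sum(margins) / len(margins)
```

The torch version in gan_losses.py presumably does the same per-row masking with tensor ops, so the commented-out copy carries no extra information.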

@@ -58,37 +58,38 @@
"source": [

henrikfo commented Apr 13, 2026


Line #28.    #y = torch.tensor([0, 1, 2, 3]).to(device) # fixed labels

The duplication and the commented-out "y" are not needed.




henrikfo (Collaborator) left a comment


Added a few small comments and pieces of feedback! Everything apart from these points looks great to me! I might have missed some context for them, so let me know if they are incorrect; otherwise, with these changes we can merge!

