add low- and high-level Concept-Based Memory Reasoner (CMR) implementation #12

Open
daviddebot wants to merge 2 commits into pyc-team:dev from daviddebot:cmr

Conversation

@daviddebot (Collaborator)

This PR adds CMR across both the low-level and high-level model APIs, with documentation updates and unit tests.

@codecov

codecov Bot commented Mar 3, 2026

Codecov Report

❌ Patch coverage is 88.18898% with 15 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| torch_concepts/nn/modules/high/models/cmr.py | 71.11% | 13 Missing ⚠️ |
| torch_concepts/nn/modules/loss.py | 91.30% | 1 Missing and 1 partial ⚠️ |


```python
        'task_target': target[:, task_indices],
    }

def shared_step(self, batch, step):
```
Collaborator:

Is it possible to avoid overriding the Lightning shared step? It is OK to instead override the model forward, since each model can have its own.

I see the problem is the task predictor, which needs to be run twice. Is it possible to handle this via the inference query in the forward?
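A minimal sketch of the idea being suggested: rather than customizing `shared_step`, the model's own `forward` runs the task predictor twice (once on predicted concepts, once on the supplied ground-truth concepts) and returns everything a generic training step needs. All class and key names here are illustrative, not the actual `torch_concepts` API.

```python
import torch
from torch import nn

class CMRSketch(nn.Module):
    """Hypothetical sketch: push the double task-predictor pass into forward."""

    def __init__(self, in_features, n_concepts, n_tasks):
        super().__init__()
        self.concept_encoder = nn.Linear(in_features, n_concepts)
        self.task_predictor = nn.Linear(n_concepts, n_tasks)

    def forward(self, x, c_true=None):
        # Predict concepts, then tasks from the (soft) predicted concepts.
        c_logits = self.concept_encoder(x)
        out = {
            "concept_pred": c_logits,
            "task_pred": self.task_predictor(torch.sigmoid(c_logits)),
        }
        # The "second pass" the reviewer mentions: rerun the task predictor
        # on ground-truth concepts inside forward, so no custom shared_step
        # (only a model-specific forward) is required downstream.
        if c_true is not None:
            out["task_pred_gt"] = self.task_predictor(c_true)
        return out
```

With this shape, a single generic step can compute all losses from the returned dict, and inference simply omits `c_true`.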

```python
    return c_loss * self.concept_weight + t_loss * self.task_weight


class CMRLoss(nn.Module):
```
@gdefe (Collaborator) Mar 19, 2026:

It would be great not to have a dedicated loss for every model, and instead to decompose the CMR loss into modular pieces. I would expect you only need to create an additional piece for the recursion part. To combine the traditional loss with this additional loss, see the way we now enable losses to be summed in the example:
pytorch_concepts/examples/utilization/2.2_model/13_composite_loss.py
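A rough sketch of the modular composition being asked for, assuming the library's actual pattern is the one in the composite-loss example above: each term (concept loss, task loss, and a new CMR-specific recursion term) stays a separate `nn.Module`, and a small combiner sums them with weights. The class name and the `preds`/`targets` dict convention are assumptions for illustration.

```python
import torch
from torch import nn

class WeightedSumLoss(nn.Module):
    """Hypothetical combiner: sum named sub-losses with per-term weights."""

    def __init__(self, losses, weights):
        super().__init__()
        self.losses = nn.ModuleDict(losses)   # name -> loss module
        self.weights = weights                # name -> float weight

    def forward(self, preds, targets):
        # Each sub-loss sees only its own prediction/target pair,
        # so adding the CMR recursion term means registering one
        # extra entry rather than writing a dedicated CMRLoss.
        total = torch.zeros(())
        for name, loss_fn in self.losses.items():
            total = total + self.weights[name] * loss_fn(preds[name], targets[name])
        return total
```

Usage under these assumptions: `WeightedSumLoss({"concept": nn.BCEWithLogitsLoss(), "task": nn.CrossEntropyLoss(), "recursion": MyRecursionTerm()}, weights)`, where `MyRecursionTerm` is the one genuinely new piece CMR would contribute.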
