From 914dff3caada52d06b9865041e311a3667f7b5f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 14:52:54 +0200 Subject: [PATCH 1/5] Add failing edge case test for NashMTL --- tests/unit/aggregation/test_nash_mtl.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index a1200d465..905197796 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -1,7 +1,7 @@ from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, randn_ +from utils.tensors import ones_, randn_, tensor_ try: from torchjd.aggregation import NashMTL @@ -19,6 +19,10 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: standard_pairs = [(_make_aggregator(matrix), matrix) for matrix in nash_mtl_matrices] +edge_case_matrices = [ + tensor_([[0.0, 0.0], [0.0, 1.0]]) # This leads to a (caught) ValueError in _solve_optimization. 
+] +edge_case_pairs = [(_make_aggregator(matrix), matrix) for matrix in edge_case_matrices] requires_grad_pairs = [(NashMTL(n_tasks=3), ones_(3, 5, requires_grad=True))] @@ -27,8 +31,10 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: @mark.filterwarnings( "ignore:Solution may be inaccurate.", "ignore:You are solving a parameterized problem that is not DPP.", + "ignore:divide by zero encountered in divide", + "ignore:invalid value encountered in matmul", ) -@mark.parametrize(["aggregator", "matrix"], standard_pairs) +@mark.parametrize(["aggregator", "matrix"], standard_pairs + edge_case_pairs) def test_expected_structure(aggregator: NashMTL, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) From 8915fc0bc0d2d3ac2c766b6ef0bf10804cf0b515 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 14:53:19 +0200 Subject: [PATCH 2/5] Catch ValueError in _solve_optimization of NashMTL --- src/torchjd/aggregation/_nash_mtl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 06e1293df..a20be6170 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -158,9 +158,10 @@ def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray: try: self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100) - except SolverError: - # On macOS, this can happen with: Solver 'ECOS' failed. + except (SolverError, ValueError): + # On macOS, SolverError can happen with: Solver 'ECOS' failed. # No idea why. The corresponding matrix is of shape [9, 11] with rank 5. + # ValueError happens, for example, with the matrix [[0., 0.], [0., 1.]]. # Maybe other exceptions can happen in other cases. 
self.alpha_param.value = self.prvs_alpha_param.value From e28ddba5885e880b8b6522564c5749bacd79dfcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 14:55:06 +0200 Subject: [PATCH 3/5] Add changelog entry --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9db0d9b8..bfd051c07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Fixed + +- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example + on the matrix [[0., 0.], [0., 1.]]). + ## [0.9.0] - 2026-02-24 ### Added From 5d0a86399c2ccfcff07f31214c3dcd623bd3bd07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Apr 2026 13:44:39 +0200 Subject: [PATCH 4/5] Add more warning ignores in nashmtl test --- tests/unit/aggregation/test_nash_mtl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index 905197796..3ea655672 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -32,6 +32,8 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: "ignore:Solution may be inaccurate.", "ignore:You are solving a parameterized problem that is not DPP.", "ignore:divide by zero encountered in divide", + "ignore:divide by zero encountered in true_divide", + "ignore:overflow encountered in divide", "ignore:invalid value encountered in matmul", ) @mark.parametrize(["aggregator", "matrix"], standard_pairs + edge_case_pairs) From e03c94297022e2f31c52be32b05ae3e437e52b6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Apr 2026 13:47:20 +0200 Subject: [PATCH 5/5] Add yet another warning ignore --- tests/unit/aggregation/test_nash_mtl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/aggregation/test_nash_mtl.py 
b/tests/unit/aggregation/test_nash_mtl.py index 3ea655672..d82fca414 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -34,6 +34,7 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: "ignore:divide by zero encountered in divide", "ignore:divide by zero encountered in true_divide", "ignore:overflow encountered in divide", + "ignore:overflow encountered in true_divide", "ignore:invalid value encountered in matmul", ) @mark.parametrize(["aggregator", "matrix"], standard_pairs + edge_case_pairs)