Skip to content

Commit 6fbc7b8

Browse files
committed
refactor(gradvac): base on GramianWeightedAggregator with GradVacWeighting
GradVac only needs gradient norms and dot products, which are fully determined by the Gramian. This makes GradVac compatible with the autogram path.

- Remove grouping parameters (group_type, encoder, shared_params) from GradVac
- Export GradVacWeighting publicly
1 parent 8d1f6e7 commit 6fbc7b8

5 files changed

Lines changed: 185 additions & 277 deletions

File tree

docs/source/docs/aggregation/gradvac.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ GradVac
77
:members:
88
:undoc-members:
99
:exclude-members: forward
10+
11+
.. autoclass:: torchjd.aggregation.GradVacWeighting
12+
:members:
13+
:undoc-members:
14+
:exclude-members: forward

src/torchjd/aggregation/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from ._dualproj import DualProj, DualProjWeighting
6767
from ._flattening import Flattening
6868
from ._graddrop import GradDrop
69-
from ._gradvac import GradVac
69+
from ._gradvac import GradVac, GradVacWeighting
7070
from ._imtl_g import IMTLG, IMTLGWeighting
7171
from ._krum import Krum, KrumWeighting
7272
from ._mean import Mean, MeanWeighting
@@ -94,6 +94,7 @@
9494
"GeneralizedWeighting",
9595
"GradDrop",
9696
"GradVac",
97+
"GradVacWeighting",
9798
"IMTLG",
9899
"IMTLGWeighting",
99100
"Krum",
Lines changed: 132 additions & 192 deletions
Original file line numberDiff line numberDiff line change
@@ -1,158 +1,124 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable
4-
from typing import Literal, cast
3+
from typing import cast
54

65
import torch
7-
import torch.nn as nn
86
from torch import Tensor
97

10-
from torchjd._linalg import Matrix
8+
from torchjd._linalg import PSDMatrix
119

12-
from ._aggregator_bases import Aggregator
10+
from ._aggregator_bases import GramianWeightedAggregator
1311
from ._utils.non_differentiable import raise_non_differentiable_error
12+
from ._weighting_bases import Weighting
1413

1514

16-
def _all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]:
17-
"""
18-
Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate
19-
``encoder.modules()`` and append the total number of elements in each module that has no child
20-
submodules and registers at least one parameter.
21-
"""
22-
23-
return tuple(
24-
sum(w.numel() for w in module.parameters())
25-
for module in encoder.modules()
26-
if len(list(module.children())) == 0 and next(module.parameters(), None) is not None
27-
)
28-
29-
30-
def _all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]:
31-
"""One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout)."""
32-
33-
return tuple(p.numel() for p in shared_params)
34-
35-
36-
class GradVac(Aggregator):
15+
class GradVac(GramianWeightedAggregator):
3716
r"""
38-
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Gradient Vaccine
39-
(GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in
40-
Massively Multilingual Models (ICLR 2021 Spotlight)
17+
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
18+
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
19+
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
4120
<https://openreview.net/forum?id=F1vEjWK-lH_>`_.
4221
43-
The input matrix is a Jacobian :math:`J \in \mathbb{R}^{m \times n}` whose rows are per-task
44-
gradients. For each task :math:`i` and each parameter block :math:`k`, the order in which other
45-
tasks :math:`j` are visited is drawn at random (independently for each :math:`k`); for each pair
46-
:math:`(i, j)` on block :math:`k`, the cosine correlation :math:`\phi_{ijk}` between the
22+
For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at
23+
random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the
4724
(possibly already modified) gradient of task :math:`i` and the original gradient of task
48-
:math:`j` on that block is compared to an EMA target :math:`\hat{\phi}_{ijk}`. When
49-
:math:`\phi_{ijk} < \hat{\phi}_{ijk}`, a closed-form correction adds a scaled copy of
50-
:math:`g_j` to the block of :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
51-
:math:`\hat{\phi}_{ijk} \leftarrow (1-\beta)\hat{\phi}_{ijk} + \beta \phi_{ijk}`. The aggregated
25+
:math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When
26+
:math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of
27+
:math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
28+
:math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated
5229
vector is the sum of the modified rows.
5330
5431
This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
55-
the number of tasks, parameter dimension, grouping, device, or dtype changes.
56-
57-
**Parameter granularity** is selected by ``group_type`` (default ``"whole_model"``). It defines
58-
how each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets
59-
:math:`\hat{\phi}_{ijk}` are computed **per block** rather than only globally:
60-
61-
* ``"whole_model"``: the full row of length :math:`n` is a single block. Cosine similarity is
62-
taken between entire task gradients. Do not pass ``encoder`` or ``shared_params``.
63-
* ``"all_layer"``: one block per leaf ``nn.Module`` under ``encoder`` that holds parameters
64-
(same rule as iterating ``encoder.modules()`` and selecting leaves with parameters). Pass
65-
``encoder``; ``shared_params`` must be omitted.
66-
* ``"all_matrix"``: one block per tensor in ``shared_params``, in iteration order. That order
67-
must match how Jacobian columns are laid out for those shared parameters. Pass
68-
``shared_params``; ``encoder`` must be omitted.
69-
70-
:param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign the
71-
:attr:`beta` attribute between steps to tune the EMA update.
72-
:param group_type: Granularity of parameter grouping; see **Parameter granularity** above.
73-
:param encoder: Module whose subtree defines ``all_layer`` blocks when
74-
``group_type == "all_layer"``.
75-
:param shared_params: Iterable of parameter tensors defining ``all_matrix`` block sizes and
76-
order when ``group_type == "all_matrix"``. It is materialized once at construction.
32+
the number of tasks or dtype changes.
33+
34+
:param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign
35+
the :attr:`beta` attribute between steps to tune the EMA update.
7736
:param eps: Small non-negative constant added to denominators when computing cosines and the
7837
vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or
7938
assign the :attr:`eps` attribute between steps to tune numerical behavior.
8039
8140
.. note::
82-
GradVac is not compatible with autogram: it needs full Jacobian rows and per-block inner
83-
products, not only a Gram matrix. Only the autojac path is supported.
41+
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
42+
using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
43+
you need reproducibility.
8444
8545
.. note::
86-
For each task :math:`i` and block :math:`k`, the order of other tasks :math:`j` is shuffled
87-
independently using the global PyTorch RNG (``torch.randperm``). Seed it with
88-
``torch.manual_seed`` if you need reproducibility.
46+
To apply GradVac with per-layer or per-parameter-group granularity, first aggregate the
47+
Jacobian into groups, apply GradVac per group, and sum the results. See the grouping usage
48+
example for details.
49+
"""
50+
51+
def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
52+
weighting = GradVacWeighting(beta=beta, eps=eps)
53+
super().__init__(weighting)
54+
self._gradvac_weighting = weighting
55+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
56+
57+
@property
58+
def beta(self) -> float:
59+
"""EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""
60+
61+
return self._gradvac_weighting.beta
62+
63+
@beta.setter
64+
def beta(self, value: float) -> None:
65+
self._gradvac_weighting.beta = value
66+
67+
@property
68+
def eps(self) -> float:
69+
"""Small non-negative constant added to denominators for numerical stability."""
70+
71+
return self._gradvac_weighting.eps
72+
73+
@eps.setter
74+
def eps(self, value: float) -> None:
75+
self._gradvac_weighting.eps = value
76+
77+
def reset(self) -> None:
78+
"""Clears EMA state so the next forward starts from zero targets."""
79+
80+
self._gradvac_weighting.reset()
81+
82+
def __repr__(self) -> str:
83+
return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"
84+
85+
86+
class GradVacWeighting(Weighting[PSDMatrix]):
87+
r"""
88+
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
89+
:class:`~torchjd.aggregation.GradVac`.
90+
91+
All required quantities (gradient norms, cosine similarities, and their updates after the
92+
vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
93+
If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then:
94+
95+
.. math::
96+
97+
\|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad
98+
g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}
99+
100+
where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w
101+
g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow
102+
immediately.
103+
104+
This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
105+
the number of tasks or dtype changes.
106+
107+
:param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``).
108+
:param eps: Small non-negative constant added to denominators (default ``1e-8``).
89109
"""
90110

91-
def __init__(
92-
self,
93-
beta: float = 0.5,
94-
group_type: Literal["whole_model", "all_layer", "all_matrix"] = "whole_model",
95-
encoder: nn.Module | None = None,
96-
shared_params: Iterable[Tensor] | None = None,
97-
eps: float = 1e-8,
98-
) -> None:
111+
def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
99112
super().__init__()
100113
if not (0.0 <= beta <= 1.0):
101114
raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
102-
params_tuple: tuple[Tensor, ...] = ()
103-
fixed_block_sizes: tuple[int, ...] | None
104-
if group_type == "whole_model":
105-
if encoder is not None:
106-
raise ValueError(
107-
'Parameter `encoder` must be None when `group_type == "whole_model"`.'
108-
)
109-
if shared_params is not None:
110-
raise ValueError(
111-
'Parameter `shared_params` must be None when `group_type == "whole_model"`.'
112-
)
113-
fixed_block_sizes = None
114-
elif group_type == "all_layer":
115-
if encoder is None:
116-
raise ValueError(
117-
'Parameter `encoder` is required when `group_type == "all_layer"`.'
118-
)
119-
if shared_params is not None:
120-
raise ValueError(
121-
'Parameter `shared_params` must be None when `group_type == "all_layer"`.'
122-
)
123-
fixed_block_sizes = _all_layer_group_sizes(encoder)
124-
if sum(fixed_block_sizes) == 0:
125-
raise ValueError("Parameter `encoder` has no parameters in any leaf module.")
126-
else:
127-
if shared_params is None:
128-
raise ValueError(
129-
'Parameter `shared_params` is required when `group_type == "all_matrix"`.'
130-
)
131-
if encoder is not None:
132-
raise ValueError(
133-
'Parameter `encoder` must be None when `group_type == "all_matrix"`.'
134-
)
135-
params_tuple = tuple(shared_params)
136-
if len(params_tuple) == 0:
137-
raise ValueError(
138-
'Parameter `shared_params` must be non-empty when `group_type == "all_matrix"`.'
139-
)
140-
fixed_block_sizes = _all_matrix_group_sizes(params_tuple)
141-
142115
if eps < 0.0:
143116
raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")
144117

145118
self._beta = beta
146-
self._group_type = group_type
147-
self._encoder = encoder
148-
self._shared_params_len = len(params_tuple)
149-
self._fixed_block_sizes = fixed_block_sizes
150119
self._eps = eps
151-
152120
self._phi_t: Tensor | None = None
153-
self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None
154-
155-
self.register_full_backward_pre_hook(raise_non_differentiable_error)
121+
self._state_key: tuple[int, torch.dtype] | None = None
156122

157123
@property
158124
def beta(self) -> float:
@@ -184,82 +150,56 @@ def reset(self) -> None:
184150
self._phi_t = None
185151
self._state_key = None
186152

187-
def __repr__(self) -> str:
188-
enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)"
189-
sp = "None" if self._group_type != "all_matrix" else f"n_params={self._shared_params_len}"
190-
return (
191-
f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, "
192-
f"encoder={enc}, shared_params={sp}, eps={self._eps!r})"
193-
)
194-
195-
def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]:
196-
if self._group_type == "whole_model":
197-
return (n,)
198-
sizes = cast(tuple[int, ...], self._fixed_block_sizes)
199-
if sum(sizes) != n:
200-
raise ValueError(
201-
"The Jacobian width `n` must equal the sum of block sizes implied by "
202-
f"`encoder` or `shared_params` for this `group_type`. Found n={n}, "
203-
f"sum(block_sizes)={sum(sizes)}.",
204-
)
205-
return sizes
206-
207-
def _ensure_state(
208-
self,
209-
m: int,
210-
n: int,
211-
sizes: tuple[int, ...],
212-
device: torch.device,
213-
dtype: torch.dtype,
214-
) -> None:
215-
key = (m, n, sizes, device, dtype)
216-
num_groups = len(sizes)
217-
if self._state_key != key or self._phi_t is None:
218-
self._phi_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
219-
self._state_key = key
153+
def forward(self, gramian: PSDMatrix, /) -> Tensor:
154+
device = gramian.device
155+
dtype = gramian.dtype
156+
cpu = torch.device("cpu")
220157

221-
def forward(self, matrix: Matrix, /) -> Tensor:
222-
grads = matrix
223-
m, n = grads.shape
224-
if m == 0 or n == 0:
225-
return torch.zeros(n, dtype=grads.dtype, device=grads.device)
158+
G = cast(PSDMatrix, gramian.to(device=cpu))
159+
m = G.shape[0]
226160

227-
sizes = self._resolve_segment_sizes(n)
228-
device = grads.device
229-
dtype = grads.dtype
230-
self._ensure_state(m, n, sizes, device, dtype)
161+
self._ensure_state(m, dtype)
231162
phi_t = cast(Tensor, self._phi_t)
232-
beta = self.beta
233-
eps = self.eps
234163

235-
pc_grads = grads.clone()
236-
offsets = [0]
237-
for s in sizes:
238-
offsets.append(offsets[-1] + s)
164+
beta = self._beta
165+
eps = self._eps
166+
167+
# C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients).
168+
# Initially each modified gradient equals the original, so C = I.
169+
C = torch.eye(m, device=cpu, dtype=dtype)
239170

240171
for i in range(m):
172+
# Dot products of g_i^PC with every original g_j, shape (m,).
173+
cG = C[i] @ G
174+
241175
others = [j for j in range(m) if j != i]
242-
for k in range(len(sizes)):
243-
perm = torch.randperm(len(others))
244-
shuffled_js = [others[idx] for idx in perm.tolist()]
245-
beg, end = offsets[k], offsets[k + 1]
246-
for j in shuffled_js:
247-
slice_i = pc_grads[i, beg:end]
248-
slice_j = grads[j, beg:end]
249-
250-
norm_i = slice_i.norm()
251-
norm_j = slice_j.norm()
252-
denom = norm_i * norm_j + eps
253-
phi_ijk = slice_i.dot(slice_j) / denom
254-
255-
phi_hat = phi_t[i, j, k]
256-
if phi_ijk < phi_hat:
257-
sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
258-
sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
259-
denom_w = norm_j * sqrt_1_hat2 + eps
260-
w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
261-
pc_grads[i, beg:end] = slice_i + slice_j * w
262-
263-
phi_t[i, j, k] = (1.0 - beta) * phi_hat + beta * phi_ijk
264-
265-
return pc_grads.sum(dim=0)
176+
perm = torch.randperm(len(others))
177+
shuffled_js = [others[idx] for idx in perm.tolist()]
178+
179+
for j in shuffled_js:
180+
dot_ij = cG[j]
181+
norm_i_sq = (cG * C[i]).sum()
182+
norm_i = norm_i_sq.clamp(min=0.0).sqrt()
183+
norm_j = G[j, j].clamp(min=0.0).sqrt()
184+
denom = norm_i * norm_j + eps
185+
phi_ijk = dot_ij / denom
186+
187+
phi_hat = phi_t[i, j]
188+
if phi_ijk < phi_hat:
189+
sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
190+
sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
191+
denom_w = norm_j * sqrt_1_hat2 + eps
192+
w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
193+
C[i, j] = C[i, j] + w
194+
cG = cG + w * G[j]
195+
196+
phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk
197+
198+
weights = C.sum(dim=0)
199+
return weights.to(device)
200+
201+
def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
202+
key = (m, dtype)
203+
if self._state_key != key or self._phi_t is None:
204+
self._phi_t = torch.zeros(m, m, dtype=dtype)
205+
self._state_key = key

0 commit comments

Comments (0)