Commit c4ed86f

docs(aggregation): add grouping usage example and fix GradVac note

Add a Grouping example page covering all four strategies from the GradVac
paper (whole_model, enc_dec, all_layer, all_matrix), with a runnable code
block for each. Update the GradVac docstring note to link to the new page
instead of the previous placeholder text. Fix trailing whitespace in
CHANGELOG.md.

Made-with: Cursor

1 parent 1034bbf · commit c4ed86f

File tree

4 files changed: +178 −4 lines changed


CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ changelog does not include internal changes that do not affect the user.
 
 ### Added
 
-- Added `GradVac` and `GradVacWeighting` from
+- Added `GradVac` and `GradVacWeighting` from
   [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).
 - Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example
   on the matrix [[0., 0.], [0., 1.]]).

docs/source/examples/grouping.rst

Lines changed: 167 additions & 0 deletions (new file)

@@ -0,0 +1,167 @@

Grouping
========

When applying a conflict-resolving aggregator such as :class:`~torchjd.aggregation.GradVac` in
multi-task learning, the cosine similarities between task gradients can be computed at different
granularities. The GradVac paper introduces four strategies, each partitioning the shared
parameter vector differently:

1. **Whole Model** (default) — one group covering all shared parameters.
2. **Encoder-Decoder** — one group per top-level sub-network (e.g. encoder and decoder separately).
3. **All Layers** — one group per leaf module of the encoder.
4. **All Matrices** — one group per individual parameter tensor.

In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group.
For stateful aggregators such as :class:`~torchjd.aggregation.GradVac`, each instance
independently maintains its own EMA state :math:`\hat{\phi}`, matching the per-block targets from
the original paper.
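For intuition, the pairwise adjustment that such an instance applies can be sketched in plain Python. This is an illustrative sketch of the gradient-alteration rule from the GradVac paper, not TorchJD code, and the helper names are ours:

```python
import math

def cosine(u, v):
    """Cosine similarity between two gradient vectors given as plain lists."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def gradvac_adjust(g_i, g_j, phi_hat):
    """Rotate g_i towards g_j so that cos(g_i', g_j) reaches the EMA target phi_hat.

    Illustrative sketch of the GradVac alteration rule (not the TorchJD implementation).
    """
    phi = cosine(g_i, g_j)
    if phi >= phi_hat:  # already at least as aligned as the target: leave g_i unchanged
        return g_i
    norm_i = math.sqrt(sum(a * a for a in g_i))
    norm_j = math.sqrt(sum(b * b for b in g_j))
    coef = norm_i * (phi_hat * math.sqrt(1 - phi**2) - phi * math.sqrt(1 - phi_hat**2))
    coef /= norm_j * math.sqrt(1 - phi_hat**2)
    return [a + coef * b for a, b in zip(g_i, g_j)]

# Two conflicting task gradients (negative cosine) and a positive EMA target.
g1, g2, target = [1.0, 0.0], [-1.0, 1.0], 0.5
g1_new = gradvac_adjust(g1, g2, target)
# After the adjustment, cos(g1_new, g2) equals the target (0.5 here).
# The EMA target itself evolves as phi_hat <- (1 - beta) * phi_hat + beta * phi.
```

With per-group aggregation, each group effectively runs this rule on its own block of the gradients, with its own target.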
.. note::
    The grouping is orthogonal to the choice of
    :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions
    determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians
    are partitioned for aggregation. Calling :func:`~torchjd.autojac.jac_to_grad` once on all shared
    parameters corresponds to the Whole Model strategy. Splitting those parameters into
    sub-networks and calling :func:`~torchjd.autojac.jac_to_grad` separately on each — with a
    dedicated aggregator per sub-network — gives an arbitrary custom grouping, such as the
    Encoder-Decoder strategy described in the GradVac paper for encoder-decoder architectures.
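Formally, grouping can be viewed as block-wise aggregation of the stacked Jacobian. The notation below is a sketch of ours, not from the TorchJD docs:

```latex
% J is the m x n Jacobian of the m task losses w.r.t. the n shared parameters.
% A partition of the parameters into K groups splits J column-wise:
%   J = \left[\, J^{(1)} \;\middle|\; J^{(2)} \;\middle|\; \dots \;\middle|\; J^{(K)} \,\right],
%   \qquad J^{(k)} \in \mathbb{R}^{m \times n_k}, \quad \textstyle\sum_k n_k = n.
% Each group gets its own aggregator A_k (e.g. a dedicated GradVac instance),
% and the final update is the concatenation of the per-block aggregations:
%   g = \left( A_1\!\big(J^{(1)}\big),\, \dots,\, A_K\!\big(J^{(K)}\big) \right) \in \mathbb{R}^{n}.
% The Whole Model strategy is the special case K = 1.
```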
.. note::
    The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to
    any aggregator.

1. Whole Model
--------------

A single :class:`~torchjd.aggregation.GradVac` instance aggregates all shared parameters
together. Cosine similarities are computed between the full task gradient vectors.

.. testcode::
    :emphasize-lines: 14, 19

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    gradvac = GradVac()

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
        jac_to_grad(encoder.parameters(), gradvac)
        optimizer.step()
        optimizer.zero_grad()

2. Encoder-Decoder
------------------

One :class:`~torchjd.aggregation.GradVac` instance per top-level sub-network. Here the model
is split into an encoder and a decoder; cosine similarities are computed separately within each.
Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks
to receive Jacobians, which are then aggregated independently.

.. testcode::
    :emphasize-lines: 8-9, 15-16, 22-23

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU())
    decoder = Sequential(Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *decoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    encoder_gradvac = GradVac()
    decoder_gradvac = GradVac()

    for x, y1, y2 in zip(inputs, t1, t2):
        enc_out = encoder(x)
        dec_out = decoder(enc_out)
        mtl_backward([loss_fn(task1_head(dec_out), y1), loss_fn(task2_head(dec_out), y2)], features=dec_out)
        jac_to_grad(encoder.parameters(), encoder_gradvac)
        jac_to_grad(decoder.parameters(), decoder_gradvac)
        optimizer.step()
        optimizer.zero_grad()

3. All Layers
-------------

One :class:`~torchjd.aggregation.GradVac` instance per leaf module. Cosine similarities are
computed between the per-layer blocks of the task gradients.

.. testcode::
    :emphasize-lines: 14-15, 20-21

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    leaf_layers = [m for m in encoder.modules() if not list(m.children()) and list(m.parameters())]
    gradvacs = [GradVac() for _ in leaf_layers]

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
        for layer, gradvac in zip(leaf_layers, gradvacs):
            jac_to_grad(layer.parameters(), gradvac)
        optimizer.step()
        optimizer.zero_grad()

4. All Matrices
---------------

One :class:`~torchjd.aggregation.GradVac` instance per individual parameter tensor. Cosine
similarities are computed between the per-tensor blocks of the task gradients (e.g. weights and
biases of each layer are treated as separate groups).

.. testcode::
    :emphasize-lines: 14-15, 20-21

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    shared_params = list(encoder.parameters())
    gradvacs = [GradVac() for _ in shared_params]

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
        for param, gradvac in zip(shared_params, gradvacs):
            jac_to_grad([param], gradvac)
        optimizer.step()
        optimizer.zero_grad()

docs/source/examples/index.rst

Lines changed: 4 additions & 0 deletions

@@ -29,6 +29,9 @@ This section contains some usage examples for TorchJD.
 - :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
   TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
   ``LightningModule`` optimized by Jacobian descent.
+- :doc:`Grouping <grouping>` shows how to apply an aggregator independently per parameter group
+  (e.g. per layer), so that conflict resolution happens at a finer granularity than the full
+  shared parameter vector.
 - :doc:`Automatic Mixed Precision <amp>` shows how to combine mixed precision training with TorchJD.
 
 .. toctree::
@@ -43,3 +46,4 @@ This section contains some usage examples for TorchJD.
     monitoring.rst
     lightning_integration.rst
     amp.rst
+    grouping.rst

src/torchjd/aggregation/_gradvac.py

Lines changed: 6 additions & 3 deletions

@@ -43,9 +43,12 @@ class GradVac(GramianWeightedAggregator):
     you need reproducibility.
 
     .. note::
-        To apply GradVac with per-layer or per-parameter-group granularity, first aggregate the
-        Jacobian into groups, apply GradVac per group, and sum the results. See the grouping usage
-        example for details.
+        To apply GradVac with per-layer or per-parameter-group granularity, create a separate
+        :class:`GradVac` instance for each group and call
+        :func:`~torchjd.autojac.jac_to_grad` once per group after
+        :func:`~torchjd.autojac.mtl_backward`. Each instance maintains its own EMA state,
+        matching the per-block targets :math:`\hat{\phi}_{ijk}` from the original paper. See
+        the :doc:`Grouping </examples/grouping>` example for details.
     """
 
     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
