Commit 243bec9

feat: SpectralConv2d, SpectralConvTranspose2d
Introduced in ["Efficient Nonlinear Transforms for Lossy Image Compression"][Balle2018efficient] by Johannes Ballé, PCS 2018.

Reparameterizes the weights so that they are derived from weights stored in the frequency domain. In the original paper, this is referred to as "spectral Adam" or "Sadam" due to its effect on the Adam optimizer update rule. The motivation for representing the weights in the frequency domain is that optimizer steps then affect all frequencies equally. This improves the gradient conditioning, leading to faster convergence and increased stability at larger learning rates.

For comparison, see the TensorFlow Compression implementations of [`SignalConv2D`] and [`RDFTParameter`]. They seem to use `SignalConv2D` in most of their provided architectures:
https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code

Furthermore, since this is a simple invertible transformation of the weights, it is trivial to convert any existing pretrained weights into this form via:

```python
weight_transformed = self._to_transform_domain(weight)
```

To override `self.weight` as a property, I'm unregistering the parameter using `del self._parameters["weight"]` as shown in pytorch/pytorch#46886, and also [using the fact][property-descriptor-so] that `@property` [returns a descriptor object][property-descriptor-docs] so that attribute lookup on `self.weight` "falls back" to the property:

```python
def __init__(self, ...):
    self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
    del self._parameters["weight"]  # Unregister weight, and fall back to property.

@property
def weight(self) -> Tensor:
    return self._from_transform_domain(self.weight_transformed)
```

[Balle2018efficient]: https://arxiv.org/abs/1802.00847
[`SignalConv2D`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61
[`RDFTParameter`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71
[property-descriptor-docs]: https://docs.python.org/3/howto/descriptor.html#properties
[property-descriptor-so]: https://stackoverflow.com/a/17330273/365102
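The invertibility claim above can be checked standalone. This is a minimal sketch assuming only PyTorch; the tensor shapes and variable names are illustrative, not CompressAI's API:

```python
import torch

# Minimal sketch of the invertible weight transform: an orthonormal
# real FFT over the kernel's spatial dimensions (-2, -1).
kernel_size = (3, 3)
weight = torch.randn(8, 4, *kernel_size)  # (out_ch, in_ch, kH, kW), illustrative

# Store the weights in the frequency domain...
weight_transformed = torch.fft.rfftn(
    weight, s=kernel_size, dim=(-2, -1), norm="ortho"
)

# ...and recover the spatial-domain weights on demand. Since rfftn/irfftn
# with matching `s` are inverses, pretrained weights convert losslessly
# (up to floating-point error).
weight_restored = torch.fft.irfftn(
    weight_transformed, s=kernel_size, dim=(-2, -1), norm="ortho"
)
```

With `norm="ortho"` the transform is unitary, which is what makes the conversion of existing checkpoints a pure change of basis rather than a rescaling.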
1 parent a4ae2ee commit 243bec9

1 file changed: compressai/layers/layers.py (62 additions & 1 deletion)
```diff
@@ -27,7 +27,7 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any
+from typing import Any, Tuple
 
 import torch
 import torch.nn as nn
@@ -43,12 +43,73 @@
     "ResidualBlock",
     "ResidualBlockUpsample",
     "ResidualBlockWithStride",
+    "SpectralConv2d",
+    "SpectralConvTranspose2d",
     "conv3x3",
     "subpel_conv3x3",
     "QReLU",
 ]
 
 
+class _SpectralConvNdMixin:
+    def __init__(self, dim: Tuple[int, ...]):
+        self.dim = dim
+        self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
+        del self._parameters["weight"]  # Unregister weight, and fallback to property.
+
+    @property
+    def weight(self) -> Tensor:
+        return self._from_transform_domain(self.weight_transformed)
+
+    def _to_transform_domain(self, x: Tensor) -> Tensor:
+        return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")
+
+    def _from_transform_domain(self, x: Tensor) -> Tensor:
+        return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")
+
+
+class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin):
+    r"""Spectral 2D convolution.
+
+    Introduced in [Balle2018efficient].
+    Reparameterizes the weights to be derived from weights stored in the
+    frequency domain.
+    In the original paper, this is referred to as "spectral Adam" or
+    "Sadam" due to its effect on the Adam optimizer update rule.
+    The motivation behind representing the weights in the frequency
+    domain is that optimizer updates/steps may now affect all
+    frequencies to an equal amount.
+    This improves the gradient conditioning, thus leading to faster
+    convergence and increased stability at larger learning rates.
+
+    For comparison, see the TensorFlow Compression implementations of
+    `SignalConv2D
+    <https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61>`_
+    and
+    `RDFTParameter
+    <https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71>`_.
+
+    [Balle2018efficient]: `"Efficient Nonlinear Transforms for Lossy
+    Image Compression" <https://arxiv.org/abs/1802.00847>`_,
+    by Johannes Ballé, PCS 2018.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))
+
+
+class SpectralConvTranspose2d(nn.ConvTranspose2d, _SpectralConvNdMixin):
+    r"""Spectral 2D transposed convolution.
+
+    Transposed version of :class:`SpectralConv2d`.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))
+
+
 class MaskedConv2d(nn.Conv2d):
     r"""Masked 2D convolution implementation, mask future "unseen" pixels.
     Useful for building auto-regressive network components.
```
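As a usage sketch, the new layer can be exercised without installing compressai by reproducing the mixin from the diff (docstrings trimmed; the test tensor sizes are illustrative):

```python
from typing import Any, Tuple

import torch
import torch.nn as nn
from torch import Tensor


# Trimmed copy of the mixin from the diff, so this snippet runs standalone.
class _SpectralConvNdMixin:
    def __init__(self, dim: Tuple[int, ...]):
        self.dim = dim
        self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
        del self._parameters["weight"]  # Unregister weight; fall back to property.

    @property
    def weight(self) -> Tensor:
        return self._from_transform_domain(self.weight_transformed)

    def _to_transform_domain(self, x: Tensor) -> Tensor:
        return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")

    def _from_transform_domain(self, x: Tensor) -> Tensor:
        return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")


class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))


conv = SpectralConv2d(3, 16, kernel_size=3, padding=1)

# The trainable parameters are `weight_transformed` (complex) and `bias`;
# `weight` is no longer registered and is recomputed via the property.
param_names = {name for name, _ in conv.named_parameters()}
print(param_names)

# Drop-in replacement for nn.Conv2d in the forward pass:
y = conv(torch.randn(1, 3, 8, 8))
print(y.shape)  # same spatial shape as a regular Conv2d with these arguments
```

Note the subtlety that makes the mixin's `__init__` work: while `weight_transformed` does not exist yet, the `weight` property getter raises `AttributeError`, so attribute lookup falls through to `nn.Module.__getattr__`, which still finds the original `weight` parameter; after the `del`, the property takes over.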
