Skip to content
5 changes: 3 additions & 2 deletions escnn/kernels/steerable_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from escnn.group import Group
from escnn.group import IrreducibleRepresentation
from escnn.group import Representation
from escnn.utils import unique_ever_seen

import torch

Expand Down Expand Up @@ -238,9 +239,9 @@ def __init__(self,
js = set()

# loop over all input irreps
for i_irrep_id in set(in_repr.irreps):
for i_irrep_id in unique_ever_seen(in_repr.irreps):
# loop over all output irreps
for o_irrep_id in set(out_repr.irreps):
for o_irrep_id in unique_ever_seen(out_repr.irreps):
try:
# retrieve the irrep intertwiner basis
intertwiner_basis = irreps_basis._generator(basis, i_irrep_id, o_irrep_id, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion escnn/kernels/wignereckart_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .steerable_filters_basis import SteerableFiltersBasis

from escnn.group import *
from escnn.utils import unique_ever_seen

import torch

Expand Down Expand Up @@ -302,7 +303,7 @@ def __init__(self,
_js_restriction = defaultdict(list)

# for each harmonic j' to consider
for _j in set(_j for _j, _ in basis.js):
for _j in unique_ever_seen(_j for _j, _ in basis.js):
if basis.multiplicity(_j) == 0:
continue

Expand Down
9 changes: 5 additions & 4 deletions escnn/nn/modules/basismanager/basisexpansion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from escnn.kernels import KernelBasis, EmptyBasisException
from escnn.group import Representation
from escnn.nn.modules import utils
from escnn.utils import unique_ever_seen

from .basismanager import BasisManager
from .basisexpansion_singleblock import block_basisexpansion
Expand Down Expand Up @@ -62,8 +63,8 @@ def __init__(self,

# iterate through all different pairs of input/output representations
# and, for each of them, build a basis
for i_repr in set(in_reprs):
for o_repr in set(out_reprs):
for i_repr in unique_ever_seen(in_reprs):
for o_repr in unique_ever_seen(out_reprs):
reprs_names = (i_repr.name, o_repr.name)
try:

Expand Down Expand Up @@ -130,8 +131,8 @@ def __init__(self,
out_indices = out_indices.reshape(-1)

# register the indices tensors and the bases tensors as parameters of this module
self.register_buffer('in_indices_{}'.format(self._escape_pair(io_pair)), in_indices)
self.register_buffer('out_indices_{}'.format(self._escape_pair(io_pair)), out_indices)
self.register_buffer('in_indices_{}'.format(self._escape_pair(io_pair)), in_indices, persistent=False)
self.register_buffer('out_indices_{}'.format(self._escape_pair(io_pair)), out_indices, persistent=False)

# number of occurrences of the input/output pair `io_pair`
n_pairs = self._in_count[io_pair[0]] * self._out_count[io_pair[1]]
Expand Down
5 changes: 3 additions & 2 deletions escnn/nn/modules/basismanager/basissampler_blocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

from escnn.group import Representation
from escnn.kernels import KernelBasis, EmptyBasisException
from escnn.utils import unique_ever_seen


from escnn.nn.modules.basismanager import retrieve_indices
Expand Down Expand Up @@ -62,8 +63,8 @@ def __init__(self,

# iterate through all different pairs of input/output representations
# and, for each of them, build a basis
for i_repr in set(in_reprs):
for o_repr in set(out_reprs):
for i_repr in unique_ever_seen(in_reprs):
for o_repr in unique_ever_seen(out_reprs):
reprs_names = (i_repr.name, o_repr.name)
try:

Expand Down
5 changes: 3 additions & 2 deletions escnn/nn/modules/batchnormalization/gnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from escnn.gspaces import *
from escnn.nn import FieldType
from escnn.nn import GeometricTensor
from escnn.utils import unique_ever_seen

from ..equivariant_module import EquivariantModule

Expand Down Expand Up @@ -130,8 +131,8 @@ def __init__(self,

# store for each field type the sizes and the indices of all its irreps, grouped by their size
self._irreps_sizes = {}
for r in self.in_type._unique_representations:

for r in unique_ever_seen(self.in_type.representations):
p = 0
irreps = defaultdict(lambda: [])
trivials = []
Expand Down
20 changes: 18 additions & 2 deletions escnn/nn/modules/batchnormalization/iid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from escnn.gspaces import *
from escnn.nn import FieldType
from escnn.nn import GeometricTensor
from escnn.utils import unique_ever_seen

from ..equivariant_module import EquivariantModule

Expand Down Expand Up @@ -103,8 +104,23 @@ def __init__(self,

self._has_trivial = {}

# for each different representation in the input type
for r in self.in_type._unique_representations:
# for each different representation in the input type.

# It's important to ensure that we iterate through the representations
# in the same order every time the program runs. This order becomes
# the order that the various batch norm parameters are passed to the
# optimizer, and if that order changes between runs, then it becomes
# impossible to resume training from checkpoints [1].

# Practically, this means that we can't use a set to (more succinctly)
# eliminate duplicate representations. Set iteration order is not only
# arbitrary, but also non-deterministic, because python salts the hash
# values of some common types to protect against DOS attacks [2].
#
# [1]: https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer
# [2]: https://stackoverflow.com/questions/3848091/set-iteration-order-varies-from-run-to-run

for r in unique_ever_seen(self.in_type.representations):
p = 0
trivials = []

Expand Down
13 changes: 13 additions & 0 deletions escnn/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Iterable

def unique_ever_seen(iterable: Iterable) -> Iterable:
    """Yield the unique items of *iterable*, preserving first-seen order.

    Unlike ``set(iterable)``, the order of the yielded items is the order in
    which they first appear in the input, which is deterministic across runs.
    This matters when the iteration order determines, e.g., the order in which
    module parameters are registered.

    Items must be hashable, since uniqueness is tracked with a ``set``.
    """
    already_seen = set()

    for item in iterable:
        # Yield only on the first occurrence; later duplicates are dropped.
        if item not in already_seen:
            already_seen.add(item)
            yield item


26 changes: 22 additions & 4 deletions test/nn/test_basisexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,36 @@ def compare(self, basis: BlocksBasisExpansion):

for i, attr1 in enumerate(basis.get_basis_info()):
attr2 = basis.get_element_info(i)
self.assertEquals(attr1, attr2)
self.assertEquals(attr1['id'], i)
self.assertEqual(attr1, attr2)
self.assertEqual(attr1['id'], i)

for _ in range(5):
w = torch.randn(basis.dimension())

f1 = basis(w)
f2 = basis(w)
assert torch.allclose(f1, f2)
self.assertEquals(f1.shape[1], basis._input_size)
self.assertEquals(f1.shape[0], basis._output_size)
self.assertEqual(f1.shape[1], basis._input_size)
self.assertEqual(f1.shape[0], basis._output_size)


def test_checkpoint_meshgrid(self):
    """Check that a saved R3Conv state dict can be restored.

    I constructed this representation to trigger a bug where the
    `in_indices` and `out_indices` stored by the basis expansion module
    couldn't be restored from a checkpoint, due to the way `meshgrid()`
    was used internally.
    """
    import os
    import tempfile

    gs = rot3dOnR3()
    so3 = gs.fibergroup

    ft = FieldType(gs, [so3.irrep(1), so3.irrep(0), so3.irrep(1)])

    conv = R3Conv(ft, ft, kernel_size=3)

    # Round-trip the state dict through a temporary directory so the test
    # does not leave a stray 'demo_conv.ckpt' file in the working directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt_path = os.path.join(tmpdir, 'demo_conv.ckpt')
        torch.save(conv.state_dict(), ckpt_path)
        ckpt = torch.load(ckpt_path)

    # This shouldn't raise.
    conv.load_state_dict(ckpt)

if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions test/nn/test_basissampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def compare(self, basis: BlocksBasisSampler, d: int):

for i, attr1 in enumerate(basis.get_basis_info()):
attr2 = basis.get_element_info(i)
self.assertEquals(attr1, attr2)
self.assertEquals(attr1['id'], i)
self.assertEqual(attr1, attr2)
self.assertEqual(attr1['id'], i)

for _ in range(5):
P = 20
Expand All @@ -150,8 +150,8 @@ def compare(self, basis: BlocksBasisSampler, d: int):
f1 = basis(w, edge_delta)
f2 = basis(w, edge_delta)
self.assertTrue(torch.allclose(f1, f2))
self.assertEquals(f1.shape[2], basis._input_size)
self.assertEquals(f1.shape[1], basis._output_size)
self.assertEqual(f1.shape[2], basis._input_size)
self.assertEqual(f1.shape[1], basis._output_size)

y1 = basis.compute_messages(w, x_j, edge_delta, conv_first=False)
y2 = basis.compute_messages(w, x_j, edge_delta, conv_first=True)
Expand Down
34 changes: 34 additions & 0 deletions test/nn/test_batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,40 @@ def check_fieldnorm(self, ft: FieldType):
Standard deviations after normalization: \n {proj_std.cpu().numpy().reshape(-1)}
""")

def test_iid_bnorm_parameter_order(self):
    """The batch-norm parameters must be registered in a stable order."""
    gspace = rot3dOnR3()
    group = gspace.fibergroup
    field_type = FieldType(
        gspace,
        [group.irrep(0), group.irrep(1), group.irrep(1)] * 2,
    )

    module = IIDBatchNorm3d(field_type)

    observed_order = [name for name, _ in module.named_parameters()]
    # Note: irrep 1 contributes only a weight here (no bias entry),
    # hence the asymmetric expected list.
    self.assertEqual(
        observed_order,
        [
            'irrep_0_weight',
            'irrep_0_bias',
            'irrep_1_weight',
        ],
    )

def test_gbnorm_parameter_order(self):
    """The batch-norm parameters must be registered in a stable order."""
    gspace = rot3dOnR3()
    group = gspace.fibergroup
    field_type = FieldType(
        gspace,
        [group.irrep(0), group.irrep(1), group.irrep(1)] * 2,
    )

    module = GNormBatchNorm(field_type)

    observed_order = [name for name, _ in module.named_parameters()]
    self.assertEqual(
        observed_order,
        [
            'irrep_0_weight',
            'irrep_0_bias',
            'irrep_1_weight',
            'irrep_1_bias',
        ],
    )



if __name__ == '__main__':
unittest.main()