From 897b004ab900c4095e37463ead2e0c9301c1dc71 Mon Sep 17 00:00:00 2001 From: lars Date: Thu, 3 Nov 2022 17:45:33 +0100 Subject: [PATCH 1/4] Template code for new structured couplings --- FrEIA/core/__init__.py | 2 + FrEIA/core/invertible.py | 23 ++++++++++++ FrEIA/flows/__init__.py | 0 FrEIA/flows/base.py | 31 +++++++++++++++ FrEIA/splits/__init__.py | 2 + FrEIA/splits/base.py | 15 ++++++++ FrEIA/splits/even.py | 12 ++++++ FrEIA/splits/random.py | 3 ++ FrEIA/transforms/__init__.py | 2 + FrEIA/transforms/affine.py | 25 ++++++++++++ FrEIA/transforms/base.py | 15 ++++++++ FrEIA/transforms/coupling.py | 73 ++++++++++++++++++++++++++++++++++++ FrEIA/transforms/identity.py | 12 ++++++ FrEIA/transforms/ode.py | 59 +++++++++++++++++++++++++++++ 14 files changed, 274 insertions(+) create mode 100644 FrEIA/core/__init__.py create mode 100644 FrEIA/core/invertible.py create mode 100644 FrEIA/flows/__init__.py create mode 100644 FrEIA/flows/base.py create mode 100644 FrEIA/splits/__init__.py create mode 100644 FrEIA/splits/base.py create mode 100644 FrEIA/splits/even.py create mode 100644 FrEIA/splits/random.py create mode 100644 FrEIA/transforms/__init__.py create mode 100644 FrEIA/transforms/affine.py create mode 100644 FrEIA/transforms/base.py create mode 100644 FrEIA/transforms/coupling.py create mode 100644 FrEIA/transforms/identity.py create mode 100644 FrEIA/transforms/ode.py diff --git a/FrEIA/core/__init__.py b/FrEIA/core/__init__.py new file mode 100644 index 0000000..9ce0f83 --- /dev/null +++ b/FrEIA/core/__init__.py @@ -0,0 +1,2 @@ + +from .invertible import Invertible diff --git a/FrEIA/core/invertible.py b/FrEIA/core/invertible.py new file mode 100644 index 0000000..6a173fd --- /dev/null +++ b/FrEIA/core/invertible.py @@ -0,0 +1,23 @@ + +from abc import ABC +import torch.nn as nn + +from typing import Any + +from typing import TypeVar + +T = TypeVar("T") + + +class Invertible(ABC, nn.Module): + def forward(self, *args: T, **kwargs: T) -> Any: + raise NotImplementedError + + def inverse(self, *args, **kwargs): + raise NotImplementedError + + def __call__(self, *args, rev = False, **kwargs): + if not rev: + return self.forward(*args, **kwargs) + + return self.inverse(*args, **kwargs) diff --git a/FrEIA/flows/__init__.py b/FrEIA/flows/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/FrEIA/flows/base.py b/FrEIA/flows/base.py new file mode 100644 index 0000000..7961acd --- /dev/null +++ b/FrEIA/flows/base.py @@ -0,0 +1,31 @@ + +from freia.core import Invertible + +class Flow(Invertible): + def __init__(self, transform, distribution): + self.transform = transform + self.distribution = distribution + + def forward(self, x): + z, logdet = self.transform.forward(x) + + logp = self.distribution.log_prob(z) + + nll = -(logp + logdet) + + return z, nll + + def sample_transform(self, size, temperature): + z = self.distribution.sample(size, temperature) + + x, _ = self.transform.inverse(z) + + return x + + +class RecurrentFlow(Flow): + def forward(self, x): + z = x + logdet = None + for t in range(...): + z, logdet = self.transform.forward(z, t) \ No newline at end of file diff --git a/FrEIA/splits/__init__.py b/FrEIA/splits/__init__.py new file mode 100644 index 0000000..7dbb05f --- /dev/null +++ b/FrEIA/splits/__init__.py @@ -0,0 +1,2 @@ + +from .even import EvenSplit diff --git a/FrEIA/splits/base.py b/FrEIA/splits/base.py new file mode 100644 index 0000000..f4bfc6a --- /dev/null +++ b/FrEIA/splits/base.py @@ -0,0 +1,15 @@ + +from freia.core import Invertible + +import torch + + +class Split(Invertible): + def __init__(self, dim: int = 1): + self.dim = dim + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + raise NotImplementedError diff --git a/FrEIA/splits/even.py b/FrEIA/splits/even.py new file mode 100644 index 0000000..53aea8a --- /dev/null +++ b/FrEIA/splits/even.py @@ -0,0 +1,12 @@ + +from .base import Split + +import torch + + +class EvenSplit(Split): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return torch.split(x, 2, dim=1) + + def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return torch.cat((x1, x2), dim=1) diff --git a/FrEIA/splits/random.py b/FrEIA/splits/random.py new file mode 100644 index 0000000..fda598e --- /dev/null +++ b/FrEIA/splits/random.py @@ -0,0 +1,3 @@ + + +# class RandomSplit() diff --git a/FrEIA/transforms/__init__.py b/FrEIA/transforms/__init__.py new file mode 100644 index 0000000..44c2a6e --- /dev/null +++ b/FrEIA/transforms/__init__.py @@ -0,0 +1,2 @@ + +from .base import Transform diff --git a/FrEIA/transforms/affine.py b/FrEIA/transforms/affine.py new file mode 100644 index 0000000..34324d0 --- /dev/null +++ b/FrEIA/transforms/affine.py @@ -0,0 +1,25 @@ + +from .base import Transform + +import torch + +from .coupling import CouplingTransform + + +class AffineTransform(CouplingTransform): + + def __init__(self): + parameter_counts = {...} + super().__init__(parameter_counts=parameter_counts) + + def transform_parameters(self, **parameters): + parameters["a"] = torch.exp(parameters["a"]) + + def _forward(self, x: torch.Tensor, **parameters) -> torch.Tensor: + parameters = self.get_parameters() + a, b = parameters["a"], parameters["b"] + return a * x + b, torch.log(a) + + def _inverse(self, z: torch.Tensor, **parameters) -> torch.Tensor: + a, b = parameters["a"], parameters["b"] + return (z - b) / a, -torch.log(a) diff --git a/FrEIA/transforms/base.py b/FrEIA/transforms/base.py new file mode 100644 index 0000000..2eae5f4 --- /dev/null +++ b/FrEIA/transforms/base.py @@ -0,0 +1,15 @@ + +from freia.core import Invertible + +import torch + + +WithJacobian = tuple[torch.Tensor, torch.Tensor] + + +class Transform(Invertible): + def forward(self, x: torch.Tensor, *, condition: torch.Tensor, **kwargs) -> WithJacobian: + raise NotImplementedError + + def inverse(self, z: torch.Tensor, **parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/FrEIA/transforms/coupling.py b/FrEIA/transforms/coupling.py new file mode 100644 index 0000000..f8c521b --- /dev/null +++ b/FrEIA/transforms/coupling.py @@ -0,0 +1,73 @@ + +from .base import Transform +from freia.splits import EvenSplit + +import torch +import torch.nn as nn + + +class Spline(Transform): + def __init__(self, affine, inner_spline): + ... + + def forward(self, x: torch.Tensor, *, condition: torch.Tensor, **kwargs) -> WithJacobian: + x[out] = affine(x[out]) + x[out] = inner_spline(x[out]) + + +class Spline(CouplingTransform): + def _forward(self): + x[in] = self._spline(...) + x[out] = self._affine(...) + + + +class CouplingTransform(Transform): + def __init__(self, transform1, transform2, subnet_constructor, split=EvenSplit(dim=1)): + self.split = split + self.subnet1 = subnet_constructor(...) + self.subnet2 = subnet_constructor(...) + + def split_parameters(self, parameters: torch.Tensor) -> dict: + pc = self.parameter_counts + parameters = torch.split(parameters, list(pc.values()), dim=1) + + return dict(zip(pc.keys(), parameters)) + + def transform_parameters(self, parameters: dict[torch.Tensor]) -> None: + pass + + def get_parameters(self, *args, **kwargs) -> dict: + raise NotImplementedError + + def get_parameters(self, u: torch.Tensor, subnet: nn.Module) -> dict: + + parameters = subnet(u) + parameters = self.split_parameters(parameters) + should_be_none = self.transform_parameters(**parameters) + if should_be_none is not None: + warnings.warn(...) + + return parameters + + + def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> torch.Tensor: + x1, x2 = self.split.forward(x) + + + + parameters = self.get_parameters(u=x2, subnet=self.subnet1) + z1, logdet1 = self.transform1.forward(x1, **parameters) + parameters = self.get_parameters(u=z1, subnet=self.subnet2) + z2, logdet2 = self.transform2(x2, **parameters) + + z = self.split.inverse(z1, z2) + logdet = logdet1 + logdet2 + + return z, logdet + + + + + +my_single_coupling = CouplingTransform(transform1=AffineTransform(...), transform2=None) diff --git a/FrEIA/transforms/identity.py b/FrEIA/transforms/identity.py new file mode 100644 index 0000000..af14d2c --- /dev/null +++ b/FrEIA/transforms/identity.py @@ -0,0 +1,12 @@ + +from .base import Transform + +import torch + + +class IdentityTransform(Transform): + def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian: + return x, 0 + + def inverse(self, z: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian: + return z, 0 diff --git a/FrEIA/transforms/ode.py b/FrEIA/transforms/ode.py new file mode 100644 index 0000000..4f658f7 --- /dev/null +++ b/FrEIA/transforms/ode.py @@ -0,0 +1,59 @@ + +from .base import Transform + +import torch + +from scipy.ode import solve_ode + + + +class Parameterized(nn.Module): + def __init__(self, *, subnet_constructor, parameter_counts): + super().__init__() + self.subnet = ... + self.parameter_counts = ... + self.transform = transform_cls + + def __call__(self, *args, **kwargs): + self.transform = transform_cls(*args, **kwargs) + + return self + + def forward(self): + parameters = self.subnet(...) + return self.transform(x, parameters) + + +@Parameterized +class ODETransform(Transform): + def __init__(self, integration_steps: int = 10): + super().__init__() + self.integration_steps = integration_steps + + def forward(self, x: torch.Tensor, **parameters) -> tuple[torch.Tensor, torch.Tensor]: + return euler(x, v, dt) + + # ode integration + dt = 1 / self.integration_steps + for _ in range(self.integration_steps): + parameters = self.get_parameters() + v = parameters["v"] + x = euler(x, v, dt) + + return x + +ODETransform = Parameterized(ODETransform) + + + + + +ode = ODETransform() + + + + +def euler(x, v, dt): + return x + v * dt + + From da5ea1ac67d4fd4c5eaaa1ef7ce3546a1c69c2d8 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 10 Nov 2022 19:45:59 +0100 Subject: [PATCH 2/4] Ideas for parameters and couplings --- FrEIA/splits/base.py | 7 +- FrEIA/splits/even.py | 4 +- FrEIA/transforms/base.py | 35 +++++++++- playground.py | 136 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 177 insertions(+), 5 deletions(-) create mode 100644 playground.py diff --git a/FrEIA/splits/base.py b/FrEIA/splits/base.py index f4bfc6a..faf61d0 100644 --- a/FrEIA/splits/base.py +++ b/FrEIA/splits/base.py @@ -1,14 +1,17 @@ -from freia.core import Invertible +from FrEIA.core import Invertible + +from typing import Tuple import torch class Split(Invertible): def __init__(self, dim: int = 1): + super().__init__() self.dim = dim - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: diff --git a/FrEIA/splits/even.py b/FrEIA/splits/even.py index 53aea8a..9feaf2c 100644 --- a/FrEIA/splits/even.py +++ b/FrEIA/splits/even.py @@ -1,11 +1,13 @@ from .base import Split +from typing import Tuple + import torch class EvenSplit(Split): - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return torch.split(x, 2, dim=1) def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: diff --git a/FrEIA/transforms/base.py b/FrEIA/transforms/base.py index 2eae5f4..341ec21 100644 --- a/FrEIA/transforms/base.py +++ b/FrEIA/transforms/base.py @@ -7,9 +7,40 @@ WithJacobian = tuple[torch.Tensor, torch.Tensor] + class Transform(Invertible): - def forward(self, x: torch.Tensor, *, condition: torch.Tensor, **kwargs) -> WithJacobian: + def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError - def inverse(self, z: torch.Tensor, **parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def inverse(self, z: torch.Tensor) -> torch.Tensor: raise NotImplementedError + + + + +@Parameterized(scale=1, shift=1) +class AffineTransform(Transform): + def forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + return scale * x + shift + + def inverse(self, z: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + return (z - shift) / scale + + +# class SplineTransform(Transform): +# def forward(self, x: torch.Tensor, edges: torch.Tensor): +# assert edges.shape == (..., self.bins) +# pass + + +class Parameterized: + def __init__(self, **parameter_counts): + self.parameter_counts = parameter_counts + + def __call__(self, cls): + + cls.forward = forward + cls.inverse = inverse + + + diff --git a/playground.py b/playground.py new file mode 100644 index 0000000..55539e5 --- /dev/null +++ b/playground.py @@ -0,0 +1,136 @@ + +import torch + +import torch.distributions + +from typing import Dict + + +class Transform: + def __init__(self): + print(f"{self.__class__.__name__} __init__") + + def __call__(self, *args, **kwargs): + print(f"{self.__class__.__name__} __call__") + + +class Parameter: + def __init__(self, count: int): + self.count = count + + def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class Real(Parameter): + def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor: + return unconstrained + + +class Positive(Parameter): + def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor: + return torch.exp(unconstrained) + + +class Increasing(Parameter): + def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor: + return unconstrained[:, 0] + torch.cumsum(torch.exp(unconstrained[:, 1:]), dim=1) + + + +class Coupling(Transform): + def __init__(self, split, transform, subnet, **parameters: Parameter): + super().__init__() + self.split = split + self.transform = transform + # TODO: 2 subnets? or just singular coupling? + self.subnet = subnet + self._parameters = parameters + + @property + def parameter_names(self): + return self._parameters.keys() + + @property + def parameter_counts(self): + return [p.count for p in self._parameters.values()] + + def get_parameters(self, condition: torch.Tensor) -> Dict[str, torch.Tensor]: + parameters = self.subnet(condition) + parameters = torch.split(parameters, self.parameter_counts, dim=1) + parameters = [p.constrain(u) for (p, u) in zip(self._parameters.keys(), parameters)] + parameters = dict(zip(self.parameter_names, parameters)) + + return parameters + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = self.split.forward(x) + parameters = self.get_parameters(x2) + z1 = self.transform.forward(x1, **parameters) + parameters = self.get_parameters(z1) + z2 = self.transform.forward(x2, **parameters) + + z = self.split.inverse(z1, z2) + + return z + + def inverse(self, z: torch.Tensor) -> torch.Tensor: + z1, z2 = self.split.forward(z) + parameters = self.get_parameters(z1) + x2 = self.transform.inverse(z2, **parameters) + parameters = self.get_parameters(x2) + x1 = self.transform.inverse(z1, **parameters) + + x = self.split.inverse(x1, x2) + + return x + + +from FrEIA.splits import EvenSplit + + +def parameterize(**parameters): + def wrap(cls): + def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs): + transform = cls(*args, **kwargs) + dims_in = ... + dims_out = ... + subnet = subnet_constructor(dims_in, dims_out) + + return Coupling(split=split, transform=transform, subnet=subnet, **parameters) + return construct + + return wrap + + +@parameterize(scale=Positive(1), shift=Real(1)) +class AffineTransform(Transform): + def forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + return scale * x + shift + + def inverse(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + return (x - shift) / scale + + +# TODO: runtime parameter counts? spline is kinda the only coupling that needs this +@parameterize(x_edges=Increasing(10), y_edges=Increasing(10), deltas=Positive(8)) +class RQSpline(Transform): + pass + + +def subnet_constructor(dims_in, dims_out): + print("subnet_constructor") + def subnet(x): + return torch.zeros(x.shape[0], dims_out) + + return subnet + + +t = AffineTransform(subnet_constructor=subnet_constructor) + +print(type(t)) + +x = None + +t(x) + From 5815db2d7e18c340aaca7c22f1c4dfd4a2e4cce4 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 10 Nov 2022 22:22:46 +0100 Subject: [PATCH 3/4] Dynamic Parameterization --- playground.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/playground.py b/playground.py index 55539e5..5ab4b0d 100644 --- a/playground.py +++ b/playground.py @@ -14,10 +14,16 @@ def __call__(self, *args, **kwargs): print(f"{self.__class__.__name__} __call__") +from typing import Callable, Union + class Parameter: - def __init__(self, count: int): + def __init__(self, count: Union[int, Callable[[Transform], int]]): self.count = count + def initialize(self, transform: Transform): + if isinstance(self.count, Callable): + self.count = self.count(transform) + def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -93,10 +99,15 @@ def parameterize(**parameters): def wrap(cls): def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs): transform = cls(*args, **kwargs) + for p in parameters.values(): + # initialize dynamic parameters + p.initialize(transform) + dims_in = ... dims_out = ... subnet = subnet_constructor(dims_in, dims_out) + return Coupling(split=split, transform=transform, subnet=subnet, **parameters) return construct @@ -112,10 +123,17 @@ def inverse(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> return (x - shift) / scale -# TODO: runtime parameter counts? spline is kinda the only coupling that needs this -@parameterize(x_edges=Increasing(10), y_edges=Increasing(10), deltas=Positive(8)) +@parameterize(x_edges=Increasing(lambda t: t.bins), y_edges=Increasing(lambda t: t.bins), deltas=Increasing(lambda t: t.bins - 1)) class RQSpline(Transform): - pass + def __init__(self, bins: int): + super().__init__() + self.bins = bins + + def forward(self, x: torch.Tensor, x_edges: torch.Tensor, y_edges: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + return torch.zeros(*x.shape) + + def inverse(self): + pass def subnet_constructor(dims_in, dims_out): @@ -134,3 +152,10 @@ def subnet(x): t(x) + +t = RQSpline(bins=8, subnet_constructor=subnet_constructor) + +t(x) + +print(t.parameter_counts) + From bd466f9ea93ae74ec215a17cc284dc1a59cdf4f9 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Fri, 11 Nov 2022 17:57:52 +0100 Subject: [PATCH 4/4] post-discussion --- playground.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/playground.py b/playground.py index 5ab4b0d..e716785 100644 --- a/playground.py +++ b/playground.py @@ -1,3 +1,4 @@ +from functools import wraps import torch @@ -51,6 +52,17 @@ def __init__(self, split, transform, subnet, **parameters: Parameter): self.transform = transform # TODO: 2 subnets? or just singular coupling? self.subnet = subnet + + test_subnet(...) + try: + self.subnet[-1].weight.data.zero_() + self.subnet[-1].bias.data.zero_() + except Exception: + dummy_output = ... + + if not zero: + warnings.warn(...) + self._parameters = parameters @property @@ -69,10 +81,16 @@ def get_parameters(self, condition: torch.Tensor) -> Dict[str, torch.Tensor]: return parameters - def forward(self, x: torch.Tensor) -> torch.Tensor: + def transform_forward(self): + pass + + def transform_inverse(self): + pass + + def forward(self, x: torch.Tensor, rev: bool = False, jac: bool = True) -> torch.Tensor: x1, x2 = self.split.forward(x) parameters = self.get_parameters(x2) - z1 = self.transform.forward(x1, **parameters) + z1 = self.transform_forward(x1, **parameters) parameters = self.get_parameters(z1) z2 = self.transform.forward(x2, **parameters) @@ -97,6 +115,7 @@ def inverse(self, z: torch.Tensor) -> torch.Tensor: def parameterize(**parameters): def wrap(cls): + @wraps(cls) def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs): transform = cls(*args, **kwargs) for p in parameters.values(): @@ -114,24 +133,29 @@ def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs): return wrap -@parameterize(scale=Positive(1), shift=Real(1)) -class AffineTransform(Transform): - def forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: +# @parameterize(scale=Positive(1), shift=Real(1)) +class AffineTransform(Coupling): + def __init__(self): + super().__init__(scale=Positive(1), shift=Real(1)) + def _forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, rev) -> torch.Tensor: return scale * x + shift - def inverse(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + def _inverse(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: return (x - shift) / scale @parameterize(x_edges=Increasing(lambda t: t.bins), y_edges=Increasing(lambda t: t.bins), deltas=Increasing(lambda t: t.bins - 1)) class RQSpline(Transform): def __init__(self, bins: int): - super().__init__() + super().__init__(x_edges=Increasing(bins)) self.bins = bins def forward(self, x: torch.Tensor, x_edges: torch.Tensor, y_edges: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: return torch.zeros(*x.shape) + def blub(self): + pass + def inverse(self): pass @@ -145,8 +169,8 @@ def subnet(x): t = AffineTransform(subnet_constructor=subnet_constructor) - -print(type(t)) +t.blub() +print(type(t.transform.blub())) x = None