diff --git a/docs/source/conf.py b/docs/source/conf.py index 8a0977c2..01d80fd6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -94,7 +94,8 @@ # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] -mathjax_config = { +mathjax_path = 'https://cdn.jsdelivr.net/npm/mathjax@2/MathJax.js?config=TeX-AMS-MML_HTMLorMML' +mathjax3_config = { "TeX": { #"packages": {'[+]': ['bm']}, "Macros": { diff --git a/e2cnn/group/groups/cyclicgroup.py b/e2cnn/group/groups/cyclicgroup.py index 0e5c2665..abcaa142 100644 --- a/e2cnn/group/groups/cyclicgroup.py +++ b/e2cnn/group/groups/cyclicgroup.py @@ -149,8 +149,8 @@ def subgroup(self, id: int) -> Tuple[Group, Callable, Callable]: # take the elements of the group generated by "r^ratio" sg = CyclicGroup(order) - parent_mapping = lambda e, ratio=ratio: e * ratio - child_mapping = lambda e, ratio=ratio: None if e % ratio != 0 else int(e // ratio) + parent_mapping = ParentMapping(ratio=ratio) + child_mapping = ChildMapping(ratio=ratio) self._subgroups[id] = sg, parent_mapping, child_mapping @@ -260,17 +260,14 @@ def irrep(self, k: int) -> IrreducibleRepresentation: if k == 0: # Trivial representation - irrep = lambda element, identity=np.eye(1): identity - character = lambda e: 1 + irrep = get_trivial_representation supported_nonlinearities = ['pointwise', 'gate', 'norm', 'gated', 'concatenated'] self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1, supported_nonlinearities=supported_nonlinearities, - # character=character, - # trivial=True, frequency=k) elif n % 2 == 0 and k == int(n/2): # 1 dimensional Irreducible representation (only for even order groups) - irrep = lambda element, k=k, base_angle=base_angle: np.array([[np.cos(k * element * base_angle)]]) + irrep = OneDimRepresentation(k, base_angle) supported_nonlinearities = ['norm', 'gated', 'concatenated'] self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1, supported_nonlinearities=supported_nonlinearities, @@ -279,7 +276,7 @@ def irrep(self, k: int) -> IrreducibleRepresentation: # 2 dimensional Irreducible Representations # build the rotation matrix with rotation frequency 'frequency' - irrep = lambda element, k=k, base_angle=base_angle: utils.psi(element * base_angle, k=k) + irrep = TwoDimRepresentation(k, base_angle) supported_nonlinearities = ['norm', 'gated'] self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 2, 2, @@ -295,3 +292,54 @@ def _generator(N: int) -> 'CyclicGroup': return _cached_group_instances[N] + # parent_mapping = lambda e, ratio=ratio: e * ratio + # child_mapping = lambda e, ratio=ratio: None if e % ratio != 0 else int(e // ratio) + +class ParentMapping: + + def __init__(self, ratio): + self.ratio = ratio + + def __call__(self, e, ratio=None): + ratio = ratio if bool(ratio) else self.ratio + return e * ratio + + +class ChildMapping: + + def __init__(self, ratio): + self.ratio = ratio + + def __call__(self, e, ratio=None): + ratio = ratio if bool(ratio) else self.ratio + return None if e % ratio != 0 else int(e // ratio) + + + +def get_trivial_representation(element, identity=np.eye(1)): + return identity + + +class OneDimRepresentation: + + def __init__(self, k, base_angle): + self.k = k + self.base_angle = base_angle + + def __call__(self, element, k=None, base_angle=None): + k = k if bool(k) else self.k + base_angle = base_angle if bool(base_angle) else self.base_angle + + return np.array([[np.cos(k * element * base_angle)]]) + + +class TwoDimRepresentation: + + def __init__(self, k, base_angle): + self.k = k + self.base_angle = base_angle + + def __call__(self, element, k=None, base_angle=None): + k = k if bool(k) else self.k + base_angle = base_angle if bool(base_angle) else self.base_angle + return utils.psi(element * base_angle, k=k) \ No newline at end of file diff --git a/e2cnn/group/representation.py b/e2cnn/group/representation.py index 1647f252..60fb51b7 100644 --- a/e2cnn/group/representation.py +++ b/e2cnn/group/representation.py @@ -737,14 +737,6 @@ def build_regular_representation(group: e2cnn.group.Group) -> Tuple[List[e2cnn.g change_of_basis_inv = change_of_basis.T return irreps, change_of_basis, change_of_basis_inv - - # return Representation(group, - # "regular", - # [r.name for r in irreps], - # change_of_basis, - # ['pointwise', 'norm', 'gated', 'concatenated'], - # representation=representation, - # change_of_basis_inv=change_of_basis_inv) def build_quotient_representation(group: e2cnn.group.Group, @@ -1048,10 +1040,36 @@ def direct_sum_factory(irreps: List[e2cnn.group.IrreducibleRepresentation], unique_irreps = list({irr.name: irr for irr in irreps}.items()) irreps_names = [irr.name for irr in irreps] - - def direct_sum(element, - irreps_names=irreps_names, change_of_basis=change_of_basis, - change_of_basis_inv=change_of_basis_inv, unique_irreps=unique_irreps): + + return DirectSum( + irreps_names=irreps_names, + change_of_basis=change_of_basis, + change_of_basis_inv=change_of_basis_inv, + unique_irreps=unique_irreps + ) + + +class DirectSum: + + def __init__(self, irreps_names, change_of_basis, change_of_basis_inv, unique_irreps): + self.irreps_names = irreps_names + self.change_of_basis = change_of_basis + self.change_of_basis_inv = change_of_basis_inv + self.unique_irreps = unique_irreps + + def __call__( + self, + element, + irreps_names=None, + change_of_basis=None, + change_of_basis_inv=None, + unique_irreps=None + ): + irreps_names = irreps_names if bool(irreps_names) else self.irreps_names + change_of_basis = change_of_basis if bool(change_of_basis) else self.change_of_basis + change_of_basis_inv = change_of_basis_inv if bool(change_of_basis_inv) else self.change_of_basis_inv + unique_irreps = unique_irreps if bool(unique_irreps) else self.unique_irreps + reprs = {} for n, irr in unique_irreps: reprs[n] = irr(element) @@ -1064,8 +1082,6 @@ def direct_sum(element, P = sparse.block_diag(blocks, format='csc') return change_of_basis @ P @ change_of_basis_inv - - return direct_sum def null(A: Union[np.matrix, sparse.linalg.LinearOperator], diff --git a/e2cnn/gspaces/r2/general_r2.py b/e2cnn/gspaces/r2/general_r2.py index b0f6f8a0..d83a55c2 100644 --- a/e2cnn/gspaces/r2/general_r2.py +++ b/e2cnn/gspaces/r2/general_r2.py @@ -34,7 +34,7 @@ def __init__(self, fibergroup: Group, name: str): # Store the computed intertwiners between irreps # - key = (filter size, sigma, rings) # - value = dictionary mapping (input_irrep, output_irrep) pairs to the corresponding basis - self._irreps_intertwiners_basis_memory = defaultdict(lambda: dict()) + self._irreps_intertwiners_basis_memory = defaultdict(dict) # Store the computed intertwiners between general representations # - key = (filter size, sigma, rings) diff --git a/e2cnn/kernels/basis.py b/e2cnn/kernels/basis.py index 50c5e8a1..edd6cbc9 100644 --- a/e2cnn/kernels/basis.py +++ b/e2cnn/kernels/basis.py @@ -124,7 +124,7 @@ def sample_masked(self, points: np.ndarray, mask: np.ndarray, out: np.ndarray = """ assert mask.shape == (self.dim, ) - assert mask.dtype == np.bool + assert mask.dtype == bool basis = self.sample(points) diff --git a/e2cnn/nn/modules/r2upsampling.py b/e2cnn/nn/modules/r2upsampling.py index 369c95e2..7d827543 100644 --- a/e2cnn/nn/modules/r2upsampling.py +++ b/e2cnn/nn/modules/r2upsampling.py @@ -35,7 +35,7 @@ def __init__(self, However, ``"nearest"`` is not equivariant; using this method may result in broken equivariance. For this reason, we suggest to use ``"bilinear"`` (default value). - ..warning :: + .. warning :: The module supports a ``size`` parameter as an alternative to ``scale_factor``. However, the use of ``scale_factor`` should be *preferred*, since it guarantees both axes are scaled uniformly, which preserves rotation equivariance.