Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] multiprocessing bug with pickle #69

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

mathjax_config = {
mathjax_path = 'https://cdn.jsdelivr.net/npm/mathjax@2/MathJax.js?config=TeX-AMS-MML_HTMLorMML'
mathjax3_config = {
"TeX": {
#"packages": {'[+]': ['bm']},
"Macros": {
Expand Down
64 changes: 56 additions & 8 deletions e2cnn/group/groups/cyclicgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def subgroup(self, id: int) -> Tuple[Group, Callable, Callable]:
# take the elements of the group generated by "r^ratio"
sg = CyclicGroup(order)

parent_mapping = lambda e, ratio=ratio: e * ratio
child_mapping = lambda e, ratio=ratio: None if e % ratio != 0 else int(e // ratio)
parent_mapping = ParentMapping(ratio=ratio)
child_mapping = ChildMapping(ratio=ratio)

self._subgroups[id] = sg, parent_mapping, child_mapping

Expand Down Expand Up @@ -260,17 +260,14 @@ def irrep(self, k: int) -> IrreducibleRepresentation:
if k == 0:
# Trivial representation

irrep = lambda element, identity=np.eye(1): identity
character = lambda e: 1
irrep = get_trivial_representation
supported_nonlinearities = ['pointwise', 'gate', 'norm', 'gated', 'concatenated']
self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
supported_nonlinearities=supported_nonlinearities,
# character=character,
# trivial=True,
frequency=k)
elif n % 2 == 0 and k == int(n/2):
# 1 dimensional Irreducible representation (only for even order groups)
irrep = lambda element, k=k, base_angle=base_angle: np.array([[np.cos(k * element * base_angle)]])
irrep = OneDimRepresentation(k, base_angle)
supported_nonlinearities = ['norm', 'gated', 'concatenated']
self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
supported_nonlinearities=supported_nonlinearities,
Expand All @@ -279,7 +276,7 @@ def irrep(self, k: int) -> IrreducibleRepresentation:
# 2 dimensional Irreducible Representations

# build the rotation matrix with rotation frequency 'frequency'
irrep = lambda element, k=k, base_angle=base_angle: utils.psi(element * base_angle, k=k)
irrep = TwoDimRepresentation(k, base_angle)

supported_nonlinearities = ['norm', 'gated']
self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 2, 2,
Expand All @@ -295,3 +292,54 @@ def _generator(N: int) -> 'CyclicGroup':

return _cached_group_instances[N]

# parent_mapping = lambda e, ratio=ratio: e * ratio
# child_mapping = lambda e, ratio=ratio: None if e % ratio != 0 else int(e // ratio)

class ParentMapping:

def __init__(self, ratio):
self.ratio = ratio

def __call__(self, e, ratio=None):
ratio = ratio if bool(ratio) else self.ratio
return e * ratio


class ChildMapping:

def __init__(self, ratio):
self.ratio = ratio

def __call__(self, e, ratio=None):
ratio = ratio if bool(ratio) else self.ratio
return None if e % ratio != 0 else int(e // ratio)



def get_trivial_representation(element, identity=np.eye(1)):
return identity


class OneDimRepresentation:

def __init__(self, k, base_angle):
self.k = k
self.base_angle = base_angle

def __call__(self, element, k=None, base_angle=None):
k = k if bool(k) else self.k
base_angle = base_angle if bool(base_angle) else self.base_angle

return np.array([[np.cos(k * element * base_angle)]])


class TwoDimRepresentation:

def __init__(self, k, base_angle):
self.k = k
self.base_angle = base_angle

def __call__(self, element, k=None, base_angle=None):
k = k if bool(k) else self.k
base_angle = base_angle if bool(base_angle) else self.base_angle
return utils.psi(element * base_angle, k=k)
44 changes: 30 additions & 14 deletions e2cnn/group/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,14 +737,6 @@ def build_regular_representation(group: e2cnn.group.Group) -> Tuple[List[e2cnn.g
change_of_basis_inv = change_of_basis.T

return irreps, change_of_basis, change_of_basis_inv

# return Representation(group,
# "regular",
# [r.name for r in irreps],
# change_of_basis,
# ['pointwise', 'norm', 'gated', 'concatenated'],
# representation=representation,
# change_of_basis_inv=change_of_basis_inv)


def build_quotient_representation(group: e2cnn.group.Group,
Expand Down Expand Up @@ -1048,10 +1040,36 @@ def direct_sum_factory(irreps: List[e2cnn.group.IrreducibleRepresentation],

unique_irreps = list({irr.name: irr for irr in irreps}.items())
irreps_names = [irr.name for irr in irreps]

def direct_sum(element,
irreps_names=irreps_names, change_of_basis=change_of_basis,
change_of_basis_inv=change_of_basis_inv, unique_irreps=unique_irreps):

return DirectSum(
irreps_names=irreps_names,
change_of_basis=change_of_basis,
change_of_basis_inv=change_of_basis_inv,
unique_irreps=unique_irreps
)


class DirectSum:

def __init__(self, irreps_names, change_of_basis, change_of_basis_inv, unique_irreps):
self.irreps_names = irreps_names
self.change_of_basis = change_of_basis
self.change_of_basis_inv = change_of_basis_inv
self.unique_irreps = unique_irreps

def __call__(
self,
element,
irreps_names=None,
change_of_basis=None,
change_of_basis_inv=None,
unique_irreps=None
):
irreps_names = irreps_names if bool(irreps_names) else self.irreps_names
change_of_basis = change_of_basis if bool(change_of_basis) else self.change_of_basis
change_of_basis_inv = change_of_basis_inv if bool(change_of_basis_inv) else self.change_of_basis_inv
unique_irreps = unique_irreps if bool(unique_irreps) else self.unique_irreps

reprs = {}
for n, irr in unique_irreps:
reprs[n] = irr(element)
Expand All @@ -1064,8 +1082,6 @@ def direct_sum(element,
P = sparse.block_diag(blocks, format='csc')

return change_of_basis @ P @ change_of_basis_inv

return direct_sum


def null(A: Union[np.matrix, sparse.linalg.LinearOperator],
Expand Down
2 changes: 1 addition & 1 deletion e2cnn/gspaces/r2/general_r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, fibergroup: Group, name: str):
# Store the computed intertwiners between irreps
# - key = (filter size, sigma, rings)
# - value = dictionary mapping (input_irrep, output_irrep) pairs to the corresponding basis
self._irreps_intertwiners_basis_memory = defaultdict(lambda: dict())
self._irreps_intertwiners_basis_memory = defaultdict(dict)

# Store the computed intertwiners between general representations
# - key = (filter size, sigma, rings)
Expand Down
2 changes: 1 addition & 1 deletion e2cnn/kernels/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def sample_masked(self, points: np.ndarray, mask: np.ndarray, out: np.ndarray =
"""

assert mask.shape == (self.dim, )
assert mask.dtype == np.bool
assert mask.dtype == bool

basis = self.sample(points)

Expand Down
2 changes: 1 addition & 1 deletion e2cnn/nn/modules/r2upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self,
However, ``"nearest"`` is not equivariant; using this method may result in broken equivariance.
For this reason, we suggest to use ``"bilinear"`` (default value).

..warning ::
.. warning ::
The module supports a ``size`` parameter as an alternative to ``scale_factor``.
However, the use of ``scale_factor`` should be *preferred*, since it guarantees both axes are scaled
uniformly, which preserves rotation equivariance.
Expand Down