Skip to content

Commit

Permalink
Feat: Add OrthLinear and OrthBilinear
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Mar 3, 2024
1 parent 0df3792 commit 64a46b3
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions torchglyph/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


class Linear(nn.Module):
def __init__(self, bias: bool = True, *, in_features: int,
out_features: int, leading_features: Tuple[int, ...] = ()) -> None:
def __init__(self, bias: bool = True, *, in_features: int, out_features: int,
leading_features: Tuple[int, ...] = ()) -> None:
super(Linear, self).__init__()

self.in_features = in_features
Expand All @@ -20,11 +20,15 @@ def __init__(self, bias: bool = True, *, in_features: int,
self.reset_parameters()

def extra_repr(self) -> str:
return ', '.join([
args = [
f'in_features={self.in_features}',
f'out_features={self.out_features}',
f'bias={self.bias is not None}',
])
]
if self.leading_features != ():
args.append(f'leading_features={self.leading_features}')

return ', '.join(args)

def reset_parameters(self) -> None:
bound = self.in_features ** -0.5
Expand All @@ -47,6 +51,13 @@ def reset_parameters(self) -> None:
init.zeros_(self.bias)


class OrthLinear(Linear):
def reset_parameters(self) -> None:
init.orthogonal_(self.weight)
if self.bias is not None:
init.zeros_(self.bias)


class ZeroLinear(Linear):
def reset_parameters(self) -> None:
init.zeros_(self.weight)
Expand All @@ -55,8 +66,8 @@ def reset_parameters(self) -> None:


class Bilinear(nn.Module):
def __init__(self, bias: bool = True, *, in_features1: int, in_features2: int,
out_features: int, leading_features: Tuple[int, ...] = ()) -> None:
def __init__(self, bias: bool = True, *, in_features1: int, in_features2: int, out_features: int,
leading_features: Tuple[int, ...] = ()) -> None:
super(Bilinear, self).__init__()

self.in_features1 = in_features1
Expand All @@ -70,12 +81,16 @@ def __init__(self, bias: bool = True, *, in_features1: int, in_features2: int,
self.reset_parameters()

def extra_repr(self) -> str:
return ', '.join([
args = [
f'in_features1={self.in_features1}',
f'in_features2={self.in_features2}',
f'out_features={self.out_features}',
f'bias={self.bias is not None}',
])
]
if self.leading_features != ():
args.append(f'leading_features={self.leading_features}')

return ', '.join(args)

def reset_parameters(self) -> None:
bound = max(self.in_features1, self.in_features2) ** -0.5
Expand All @@ -98,6 +113,13 @@ def reset_parameters(self) -> None:
init.zeros_(self.bias)


class OrthBilinear(Bilinear):
def reset_parameters(self) -> None:
init.orthogonal_(self.weight)
if self.bias is not None:
init.zeros_(self.bias)


class ZeroBilinear(Bilinear):
def reset_parameters(self) -> None:
init.zeros_(self.weight)
Expand All @@ -108,11 +130,13 @@ def reset_parameters(self) -> None:
Proj = Union[
Type[Linear],
Type[NormLinear],
Type[OrthLinear],
Type[ZeroLinear],
]

Biproj = Union[
Type[Bilinear],
Type[NormBilinear],
Type[OrthBilinear],
Type[ZeroBilinear],
]

0 comments on commit 64a46b3

Please sign in to comment.