mlp.py
import numpy as np
import jittor as jt
import jittor.nn as nn
from wn_linear import WNLinear


class MLPforNeuralSDF(nn.Module):
    def __init__(self, layer_dims, skip_connection=[], activ=None, use_layernorm=False, use_weightnorm=False,
                 geometric_init=False, out_bias=0., invert=False):
        """Initialize a multi-layer perceptron with skip connections.

        Args:
            layer_dims: A list of integers representing the number of channels in each layer.
            skip_connection: A list of integers representing the indices of layers that receive
                a skip connection from the network input.
            activ: Activation function applied after each hidden layer (defaults to jt.nn.relu).
            use_layernorm: If True, apply LayerNorm before the activation of each hidden layer.
            use_weightnorm: If True, wrap each linear layer with weight normalization (WNLinear).
            geometric_init: If True, apply the geometric initialization scheme for SDF networks.
            out_bias: Bias subtracted from the SDF output, setting the radius of the initial sphere.
            invert: If True, negate the SDF prediction layer to flip inside and outside.
        """
super().__init__()
self.skip_connection = skip_connection
self.use_layernorm = use_layernorm
self.linears = nn.ModuleList()
if use_layernorm:
self.layer_norm = nn.ModuleList()
# Hidden layers
layer_dim_pairs = list(zip(layer_dims[:-1], layer_dims[1:]))
for li, (k_in, k_out) in enumerate(layer_dim_pairs):
if li in self.skip_connection:
k_in += layer_dims[0]
linear = nn.Linear(k_in, k_out)
if geometric_init:
self._geometric_init(linear, k_in, k_out, first=(li == 0),
skip_dim=(layer_dims[0] if li in self.skip_connection else 0))
if use_weightnorm:
linear = WNLinear(linear)
self.linears.append(linear)
if use_layernorm and li != len(layer_dim_pairs) - 1:
self.layer_norm.append(nn.LayerNorm(k_out))
            if li == len(layer_dim_pairs) - 1:
                nn.init.constant_(self.linears[-1].bias, 0.0)  # zero the bias of the final feature layer
        # SDF prediction layer, branching off the input to the last hidden layer
        # (k_in here is the input width of the final layer, left over from the loop above)
        self.linear_sdf = nn.Linear(k_in, 1)
if geometric_init:
self._geometric_init_sdf(self.linear_sdf, k_in, out_bias=out_bias, invert=invert)
        self.activ = activ or jt.nn.relu

    def execute(self, input, with_sdf=True, with_feat=True):
feat = input
for li, linear in enumerate(self.linears):
if li in self.skip_connection:
feat = jt.concat([feat, input], dim=-1)
if li != len(self.linears) - 1 or with_feat:
feat_pre = linear(feat)
                # LayerNorm modules exist only for hidden layers (see __init__), so guard the last layer.
                if self.use_layernorm and li != len(self.linears) - 1:
                    feat_pre = self.layer_norm[li](feat_pre)
feat_activ = self.activ(feat_pre)
if li == len(self.linears) - 1:
out = [self.linear_sdf(feat) if with_sdf else None,
feat_activ if with_feat else None]
feat = feat_activ
        return out

    def _geometric_init(self, linear, k_in, k_out, first=False, skip_dim=0):
nn.init.constant_(linear.bias, 0.0)
nn.init.gauss_(linear.weight, 0.0, np.sqrt(2 / k_out))
        if first:
            # Zero all input channels except the first three (raw xyz); the remaining
            # channels carry positional encodings.
            nn.init.constant_(linear.weight[:, 3:], 0.0)
        if skip_dim:
            # Zero the channels fed by the skip connection from the input.
            nn.init.constant_(linear.weight[:, -skip_dim:], 0.0)

    def _geometric_init_sdf(self, linear, k_in, out_bias=0., invert=False):
        # Initialize weights so the network starts out approximating the SDF of a sphere
        # (mean sqrt(pi / k_in), tiny std), with the radius controlled by out_bias.
        nn.init.gauss_(linear.weight, mean=np.sqrt(np.pi / k_in), std=0.0001)
        nn.init.constant_(linear.bias, -out_bias)
        if invert:
            # jt.Var.data yields a NumPy copy, so in-place edits on it may not write back
            # to the parameter; assign() updates the Var itself.
            linear.weight.assign(linear.weight * -1)
            linear.bias.assign(linear.bias * -1)
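

# A minimal usage sketch (not from the original repository): the layer sizes, skip
# index, batch size, and the 39-channel inputs below (3 raw xyz + 36 positional
# encoding channels) are illustrative assumptions, not values taken from the repo.
if __name__ == "__main__":
    # Four linear layers; the input to layer 2 is concatenated with the raw input.
    mlp = MLPforNeuralSDF(layer_dims=[39, 64, 64, 64, 64], skip_connection=[2],
                          geometric_init=True, out_bias=0.5)
    points = jt.rand(8, 39)  # hypothetical batch of encoded 3D query points
    sdf, feat = mlp(points, with_sdf=True, with_feat=True)
    print(sdf.shape, feat.shape)  # expected: (8, 1) and (8, 64)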