-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
264 lines (223 loc) · 10.7 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from moe_module import *
from itertools import combinations
class FlexMoE(nn.Module):
    """Stack of ``TransformerEncoderLayer`` blocks followed by an ``MLP`` head.

    The encoder layers alternate between a sparse (mixture-of-experts) MLP and
    a dense MLP, starting with a sparse one.  The last entry of
    ``self.network`` is the prediction head; every entry before it is an
    encoder layer.
    """

    def __init__(self, num_modalities, full_modality_index, num_patches, hidden_dim, output_dim, num_layers, num_layers_pred, num_experts, num_routers, top_k, num_heads=2, dropout=0.5):
        """Build the encoder stack, the prediction head and the position embedding.

        Args:
            num_modalities: Number of modality streams passed to ``forward``.
            full_modality_index: Forwarded unchanged to every encoder layer.
            num_patches: Patches per modality; sizes the position embedding.
            hidden_dim: Token width used throughout the encoder.
            output_dim: Width of the prediction head's output.
            num_layers: Requested number of encoder layers.  At least one
                layer is always created, even if this is less than 1.
            num_layers_pred: ``num_layers`` argument of the prediction ``MLP``.
            num_experts: Forwarded to every encoder layer.
            num_routers: Forwarded to every encoder layer.
            top_k: Forwarded to every encoder layer.
            num_heads: Attention heads per encoder layer.
            dropout: Dropout probability of the encoder layers.
        """
        super().__init__()
        # Even depths (0, 2, ...) use the sparse MLP, odd depths the dense one.
        encoder_layers = [
            TransformerEncoderLayer(
                num_experts,
                num_routers,
                hidden_dim,
                num_head=num_heads,
                dropout=dropout,
                hidden_times=2,
                mlp_sparse=(depth % 2 == 0),
                full_modality_index=full_modality_index,
                top_k=top_k,
            )
            for depth in range(max(num_layers, 1))
        ]
        # NOTE(review): the head's dropout is fixed at 0.5 and does not follow
        # the ``dropout`` argument -- kept as-is to preserve behaviour; confirm
        # whether that is intentional.
        head = MLP(hidden_dim * num_modalities, hidden_dim, output_dim, num_layers_pred, activation=nn.ReLU(), dropout=0.5)
        self.network = nn.Sequential(*encoder_layers, head)
        # One learnable embedding per (modality, patch) position, shared
        # across the batch via broadcasting on dim 0.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches * num_modalities, hidden_dim))
        self.combination_to_index = self._create_combination_index(num_modalities)

    def forward(self, *inputs, expert_indices=None, is_full_modality=None):
        """Encode the modality streams and return the head's prediction.

        Args:
            *inputs: One tensor per modality.  They are concatenated along
                dim 1, so they must agree on every other dimension
                (presumably ``(batch, num_patches, hidden_dim)`` -- TODO confirm
                against the caller).
            expert_indices: If not ``None``, pushed into every encoder layer
                that exposes ``set_expert_index`` before that layer runs.
            is_full_modality: Accepted for interface compatibility; this
                method does not read it.

        Returns:
            Output of the final ``MLP`` applied to the per-modality token
            means concatenated along dim 1.
        """
        chunk_size = [tokens.shape[1] for tokens in inputs]
        x = torch.cat(inputs, dim=1)
        # Identity test: ``!=`` on a Parameter is a tensor comparison operator
        # and only yields a usable bool through Python's fallback rules.
        if self.pos_embed is not None:
            x += self.pos_embed
        x = torch.split(x, chunk_size, dim=1)

        *encoder_layers, head = list(self.network)
        for layer in encoder_layers:
            if expert_indices is not None and hasattr(layer, 'set_expert_index'):
                layer.set_expert_index(expert_indices)
            x = layer(x)

        # Pool every modality over its token axis, then fuse by concatenation.
        pooled = torch.cat([item.mean(dim=1) for item in x], dim=1)
        return head(pooled)

    def gate_loss(self):
        """Sum the losses reported by every gate found in the sub-modules.

        Every sub-module that has an ``all_gates`` attribute is queried; its
        gates are looked up by the string keys ``'0'``, ``'1'``, ... and each
        gate's ``get_loss()`` result is collected.  Gates returning ``None``
        are reported with a warning and skipped.

        Returns:
            The sum of the collected losses (the integer ``0`` when none were
            collected).
        """
        g_loss = []
        for mn, mm in self.named_modules():
            if not hasattr(mm, 'all_gates'):
                continue
            for i in range(len(mm.all_gates)):
                i_loss = mm.all_gates[f'{i}'].get_loss()
                if i_loss is None:
                    print(f"[WARN] The gate loss if {mn}, modality: {i} is emtpy, check weather call <get_loss> twice.")
                else:
                    g_loss.append(i_loss)
        return sum(g_loss)

    def _create_combination_index(self, num_modalities):
        """Map every non-empty modality subset to a consecutive integer.

        Subsets are enumerated by increasing size and, within a size, in the
        order produced by ``itertools.combinations``.  Keys are sorted tuples
        of modality indices.
        """
        subsets = (
            subset
            for size in range(1, num_modalities + 1)
            for subset in combinations(range(num_modalities), size)
        )
        return {tuple(sorted(subset)): idx for idx, subset in enumerate(subsets)}

    def assign_expert(self, combination):
        """Return the index registered for ``combination`` (order-insensitive), or ``None``."""
        return self.combination_to_index.get(tuple(sorted(combination)))

    def set_full_modality(self, is_full_modality):
        """Forward ``is_full_modality`` to every layer that exposes ``set_full_modality``."""
        for layer in self.network:
            if hasattr(layer, 'set_full_modality'):
                layer.set_full_modality(is_full_modality)
class MLP(nn.Module):
    """Feed-forward network: ``Linear -> activation -> dropout`` blocks plus a final ``Linear``.

    With ``num_layers == 1`` the network is a single ``Linear`` from
    ``input_dim`` to ``output_dim``.  Otherwise it has one input block,
    ``num_layers - 2`` hidden blocks and an output ``Linear``.  The same
    ``activation`` module and the same ``nn.Dropout`` instance (``self.drop``)
    are reused in every block.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, activation=nn.ReLU(), dropout=0.5):
        super().__init__()
        self.drop = nn.Dropout(dropout)

        if num_layers == 1:
            stack = [nn.Linear(input_dim, output_dim)]
        else:
            # Widths seen by the activated blocks: input -> hidden (-> hidden ...).
            # There is always at least one activated block on this path.
            widths = [input_dim] + [hidden_dim] * max(num_layers - 1, 1)
            stack = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                stack += [nn.Linear(fan_in, fan_out), activation, self.drop]
            stack.append(nn.Linear(hidden_dim, output_dim))

        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        return self.network(x)
class Custom3DCNN(nn.Module):
    """Three-stage 3D convolutional encoder ending in a linear projection.

    Architecture provided by: End-To-End Alzheimer's Disease Diagnosis and
    Biomarker Identification.

    Each stage is two 3x3x3 convolutions with ReLU, a max-pool and a
    ``Dropout3d(0.2)``.  The flattened result is projected to ``hidden_dim``.
    The linear layer expects ``hidden_dim * 3 * 3 * 4`` flattened features,
    i.e. a 3x3x4 spatial grid after the last pool, which fixes the input
    volume sizes this module accepts.
    """

    def __init__(self, hidden_dim=128):
        super().__init__()
        cube = (3, 3, 3)

        # Stage 1: 1 -> 32 channels, spatial size halved.
        self.conv1 = nn.Conv3d(1, 32, kernel_size=cube, stride=1, padding=1)
        self.conv2 = nn.Conv3d(32, 32, kernel_size=cube, stride=1, padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        self.dropout1 = nn.Dropout3d(0.2)

        # Stage 2: 32 -> 64 channels, spatial size divided by three.
        self.conv3 = nn.Conv3d(32, 64, kernel_size=cube, stride=1, padding=1)
        self.conv4 = nn.Conv3d(64, 64, kernel_size=cube, stride=1, padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=3)
        self.dropout2 = nn.Dropout3d(0.2)

        # Stage 3: 64 -> 128 -> hidden_dim channels, pooled with a 4x4x4 window.
        self.conv5 = nn.Conv3d(64, 128, kernel_size=cube, stride=1, padding=1)
        self.conv6 = nn.Conv3d(128, hidden_dim, kernel_size=cube, stride=1, padding=1)
        self.pool3 = nn.MaxPool3d(kernel_size=(4, 4, 4))
        self.dropout3 = nn.Dropout3d(0.2)

        # Projection of the flattened feature map down to hidden_dim.
        self.fc = nn.Linear(hidden_dim * 3 * 3 * 4, hidden_dim)

    def forward(self, x):
        """Run the three conv stages, flatten per sample and project with ``fc``."""
        stages = (
            (self.conv1, self.conv2, self.pool1, self.dropout1),
            (self.conv3, self.conv4, self.pool2, self.dropout2),
            (self.conv5, self.conv6, self.pool3, self.dropout3),
        )
        for first_conv, second_conv, pool, drop in stages:
            x = F.relu(first_conv(x))
            x = F.relu(second_conv(x))
            x = drop(pool(x))

        flat = x.view(x.size(0), -1)
        return self.fc(flat)
class PatchEmbeddings(nn.Module):
    """
    Feature vector to patch embedding.

    The last dimension of the input is right-padded with zeros so that it
    divides evenly into ``num_patches`` patches of ``patch_size`` values, and
    every patch is projected to ``embed_dim`` with one shared linear layer.
    The ``dropout`` argument is accepted but not used by this module.
    """

    def __init__(self, feature_size, num_patches, embed_dim, dropout=0.25):
        super().__init__()
        self.feature_size = feature_size
        self.num_patches = num_patches
        # Smallest patch width such that num_patches patches cover all features.
        self.patch_size = math.ceil(feature_size / num_patches)
        # Zeros appended on the right to fill up the last patch.
        self.pad_size = num_patches * self.patch_size - feature_size
        self.projection = nn.Linear(self.patch_size, embed_dim)

    def forward(self, x):
        """Pad, cut into ``(x.shape[0], num_patches, patch_size)`` and project each patch."""
        batch = x.shape[0]
        padded = F.pad(x, (0, self.pad_size))
        patches = padded.view(batch, self.num_patches, self.patch_size)
        return self.projection(patches)
class TransformerEncoderLayer(nn.Module):
def __init__(self,
num_experts,
num_routers,
d_model,
num_head,
dropout=0.1,
activation=nn.GELU,
hidden_times=2,
mlp_sparse = False,
self_attn = True,
full_modality_index=4,
top_k=2,
**kwargs) -> None:
super(TransformerEncoderLayer, self).__init__()
self.dropout = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation()
self.attn = Attention(
d_model, num_heads=num_head, qkv_bias=False, attn_drop=dropout, proj_drop=dropout)
self.mlp_sparse = mlp_sparse
self.self_attn = self_attn
self.expert_index = None
self.full_modality_index = full_modality_index
if self.mlp_sparse:
self.mlp = FMoETransformerMLP(num_expert=num_experts, n_router=num_routers, d_model=d_model, d_hidden=d_model * hidden_times, activation=nn.GELU(), top_k=top_k, **kwargs)
else:
self.mlp = MLP(input_dim=d_model, hidden_dim=d_model * hidden_times, output_dim=d_model, num_layers=2, activation=nn.GELU(), dropout=dropout)
def forward(self, x, attn_mask = None):
if self.self_attn:
chunk_size = [item.shape[1] for item in x]
x = self.norm1(torch.cat(x, dim=1))
kv = x
x = self.attn(x, kv, attn_mask)
x = x + self.dropout1(x)
x = torch.split(x, chunk_size, dim=1)
x = [item for item in x]
if self.mlp_sparse:
for i in range(len(chunk_size)):
x[i] = x[i] + self.dropout2(self.mlp(self.norm2(x[i]), self.expert_index))
else:
for i in range(len(chunk_size)):
x[i] = x[i] + self.dropout2(self.mlp(self.norm2(x[i])))
else:
chunk_size = [item.shape[1] for item in x]
x = [item for item in x]
for i in range(len(chunk_size)):
other_m = [x[j] for j in range(len(chunk_size)) if j != i]
other_m = torch.cat([x[i], *other_m], dim=1)
x[i] = self.attn(x[i], other_m, attn_mask)
x = [x[i]+self.dropout1(x[i]) for i in range(len(chunk_size))]
if self.mlp_sparse:
for i in range(len(chunk_size)):
x[i] = x[i] + self.dropout2(self.mlp(self.norm2(x[i]), self.expert_index))
else:
for i in range(len(chunk_size)):
x[i] = x[i] + self.dropout2(self.mlp(self.norm2(x[i])))
return x
def set_expert_index(self, expert_index):
self.expert_index = expert_index
def set_full_modality(self, is_full_modality):
if hasattr(self.mlp, 'set_full_modality'):
self.mlp.set_full_modality(is_full_modality)
class Attention(nn.Module):
    """Multi-head scaled dot-product attention with separate query and key/value inputs.

    Queries are projected from ``x`` and keys/values from ``kv``, so the same
    module serves both self-attention (``kv is x``) and cross-attention.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        """Create the projections and dropouts.

        Args:
            dim: Channel width of the inputs and of the output.
            num_heads: Number of attention heads; must divide ``dim``.
            qkv_bias: Whether the query and key/value projections have a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability on the projected output.

        Raises:
            ValueError: If ``dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # The per-head reshape in forward() requires an even split; fail here
        # with a clear message instead of a reshape error on the first call.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.head_dim = head_dim

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        # Keys and values come out of one projection and are separated in forward().
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(head_dim * self.num_heads, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, kv, attn_mask=None):
        """Attend from ``x`` (queries) to ``kv`` (keys and values).

        Args:
            x: Query tensor of shape ``(B, Nx, C)``.
            kv: Key/value source of shape ``(B, N, C)``.
            attn_mask: Accepted for interface compatibility but NOT applied;
                the attention is always unmasked.

        Returns:
            Tensor of shape ``(B, Nx, C)``.
        """
        Bx, Nx, Cx = x.shape
        B, N, C = kv.shape
        heads = self.num_heads

        # (B, Nx, C) -> (B, H, Nx, C/H)
        q = self.q(x).reshape(Bx, Nx, heads, Cx // heads).permute(0, 2, 1, 3)

        # (B, N, 2C) -> (2, B, H, N, C/H), then separate keys from values.
        packed = self.kv(kv).reshape(B, N, 2, heads, C // heads).permute(2, 0, 3, 1, 4)
        k, v = packed.unbind(0)

        # (B, H, Nx, C/H) @ (B, H, C/H, N) -> (B, H, Nx, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn.softmax(dim=-1))

        # (B, H, Nx, N) @ (B, H, N, C/H) -> (B, H, Nx, C/H) -> (B, Nx, C)
        out = (attn @ v).transpose(1, 2).reshape(Bx, Nx, -1)
        return self.proj_drop(self.proj(out))