-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMLP.py
30 lines (26 loc) · 767 Bytes
/
MLP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input tensor.
    hidden_features : int
        Number of units in the hidden layer.
    out_features : int
        Size of the last dimension of the output tensor.
    p : float
        Dropout probability. It is applied after the activation and again
        after the output projection.

    Attributes
    ----------
    fc1 : nn.Linear
        Projection from ``in_features`` to ``hidden_features``.
    act : nn.GELU
        Nonlinearity applied to the hidden activations.
    fc2 : nn.Linear
        Projection from ``hidden_features`` to ``out_features``.
    drop : nn.Dropout
        Dropout module used at both dropout sites.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        p: float = 0.0,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # One Dropout module is reused at both sites. It holds no parameters,
        # so sharing it only ties both sites to the same probability `p`.
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input whose last dimension has size ``in_features``. Leading
            dimensions pass through ``nn.Linear`` unchanged.

        Returns
        -------
        torch.Tensor
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_features``.
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x