Commit 32d44b7

Update SCCNet and EEGNeX and enhance documentation (#742)

* updating the SCCNet
* updating the whats new
* removing the comment
* using the same kernel
* exposing more parameters
* fixing a typo
* final fix for EEGNeX
* done EEGNeX
1 parent ef4e298 commit 32d44b7

4 files changed, +87 -50 lines

braindecode/models/eegnex.py

Lines changed: 31 additions & 35 deletions
@@ -3,6 +3,8 @@
 # License: BSD (3-clause)
 from __future__ import annotations
 
+import math
+
 import torch
 import torch.nn as nn
 from einops.layers.torch import Rearrange
@@ -74,6 +76,7 @@ def __init__(
         filter_1: int = 8,
         filter_2: int = 32,
         drop_prob: float = 0.5,
+        kernel_block_1_2: int = 64,
         kernel_block_4: int = 16,
         dilation_block_4: int = 2,
         avg_pool_block4: int = 4,
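
With kernel_block_1_2 exposed, the temporal kernel length of the first two convolutions can be chosen at construction time instead of being hard-coded to 64. A minimal sketch of the new parameter in use; the channel, output, and window sizes below are illustrative values, not taken from this commit:

from braindecode.models import EEGNeX

# Illustrative configuration; only kernel_block_1_2 is new in this commit.
model = EEGNeX(
    n_chans=22,
    n_outputs=4,
    n_times=512,
    kernel_block_1_2=32,  # was fixed at (1, 64) before this commit
)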
@@ -99,32 +102,24 @@ def __init__(
         self.filter_3 = self.filter_2 * self.depth_multiplier
         self.drop_prob = drop_prob
         self.activation = activation
-
+        self.kernel_block_1_2 = (1, kernel_block_1_2)
         self.kernel_block_4 = (1, kernel_block_4)
         self.dilation_block_4 = (1, dilation_block_4)
-        self.padding_block_4 = self._calc_padding(
-            self.kernel_block_4, self.dilation_block_4
-        )
         self.avg_pool_block4 = (1, avg_pool_block4)
-
         self.kernel_block_5 = (1, kernel_block_5)
         self.dilation_block_5 = (1, dilation_block_5)
-
-        self.padding_block_5 = self._calc_padding(
-            self.kernel_block_5, self.dilation_block_5
-        )
         self.avg_pool_block5 = (1, avg_pool_block5)
 
         # final layers output
-        self.in_features = self.filter_1 * (self.n_times // self.filter_2)
+        self.in_features = self._calculate_output_length()
 
         # Following paper nomenclature
         self.block_1 = nn.Sequential(
             Rearrange("batch ch time -> batch 1 ch time"),
             nn.Conv2d(
                 in_channels=1,
                 out_channels=self.filter_1,
-                kernel_size=(1, 64),
+                kernel_size=self.kernel_block_1_2,
                 padding="same",
                 bias=False,
             ),
@@ -135,7 +130,7 @@ def __init__(
             nn.Conv2d(
                 in_channels=self.filter_1,
                 out_channels=self.filter_2,
-                kernel_size=(1, 64),
+                kernel_size=self.kernel_block_1_2,
                 padding="same",
                 bias=False,
             ),
@@ -166,7 +161,7 @@ def __init__(
                 out_channels=self.filter_2,
                 kernel_size=self.kernel_block_4,
                 dilation=self.dilation_block_4,
-                padding=self.padding_block_4,
+                padding="same",
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=self.filter_2),
@@ -178,7 +173,7 @@ def __init__(
                 out_channels=self.filter_1,
                 kernel_size=self.kernel_block_5,
                 dilation=self.dilation_block_5,
-                padding=self.padding_block_5,
+                padding="same",
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=self.filter_1),
@@ -226,26 +221,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x
 
-    @staticmethod
-    def _calc_padding(
-        kernel_size: tuple[int, int], dilation: tuple[int, int]
-    ) -> tuple[int, int]:
-        """
-        Calculate padding size for 'same' convolution with dilation.
+    def _calculate_output_length(self) -> int:
+        # Pooling kernel sizes for the time dimension
+        p4 = self.avg_pool_block4[1]
+        p5 = self.avg_pool_block5[1]
 
-        Parameters
-        ----------
-        kernel_size : tuple
-            tuple containing the kernel size (height, width).
-        dilation : tuple
-            tuple containing the dilation rate (height, width).
+        # Padding for the time dimension (assumed from padding=(0, 1))
+        pad4 = 1
+        pad5 = 1
 
-        Returns
-        -------
-        tuple
-            Padding sizes (padding_height, padding_width).
-        """
-        # Calculate padding
-        padding_height = ((kernel_size[0] - 1) * dilation[0]) // 2
-        padding_width = ((kernel_size[1] - 1) * dilation[1]) // 2
-        return padding_height, padding_width
+        # Stride is assumed to be equal to kernel size (p4 and p5)
+
+        # Calculate time dimension after block 3 pooling
+        # Formula: floor((L_in + 2*padding - kernel_size) / stride) + 1
+        T3 = math.floor((self.n_times + 2 * pad4 - p4) / p4) + 1
+
+        # Calculate time dimension after block 5 pooling
+        T5 = math.floor((T3 + 2 * pad5 - p5) / p5) + 1
+
+        # Calculate final flattened features (channels * 1 * time_dim)
+        # The spatial dimension is reduced to 1 after block 3's depthwise conv.
+        final_in_features = (
+            self.filter_1 * T5
+        )  # filter_1 is the number of channels before flatten
+        return final_in_features
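
As a sanity check on _calculate_output_length, the arithmetic can be traced by hand for the n_times=437 test case added in this commit. filter_1=8 and avg_pool_block4=4 are defaults visible in this diff; avg_pool_block5=8 is assumed, since its default is not shown in these hunks:

import math

# Assumed defaults: filter_1=8, avg_pool_block4=4, avg_pool_block5=8
# (avg_pool_block5 is not visible in the hunks above).
n_times, p4, p5, pad4, pad5, filter_1 = 437, 4, 8, 1, 1, 8

T3 = math.floor((n_times + 2 * pad4 - p4) / p4) + 1  # (437 + 2 - 4) // 4 + 1 = 109
T5 = math.floor((T3 + 2 * pad5 - p5) / p5) + 1  # (109 + 2 - 8) // 8 + 1 = 13
in_features = filter_1 * T5  # 8 * 13 = 104, the value expected by the new test

The same arithmetic reproduces the other expectations in the test file: 95 and 94 samples both give 24, and 67 gives 16.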

braindecode/models/sccnet.py

Lines changed: 26 additions & 15 deletions
@@ -82,6 +82,7 @@ def __init__(
         n_spatial_filters_smooth: int = 20,
         drop_prob: float = 0.5,
         activation: nn.Module = LogActivation,
+        batch_norm_momentum: float = 0.1,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -93,23 +94,15 @@ def __init__(
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
         # Parameters
-        self.n_filters_spat = n_spatial_filters
+        self.n_spatial_filters = n_spatial_filters
         self.n_spatial_filters_smooth = n_spatial_filters_smooth
         self.drop_prob = drop_prob
 
         self.samples_100ms = int(math.floor(self.sfreq * 0.1))
         self.kernel_size_pool = int(self.sfreq * 0.5)
         # Equivalent to 0.5 seconds
 
-        # Compute the number of features for the final linear layer
-        w_out_conv2 = (
-            self.n_times - self.samples_100ms + 1  # After second conv layer
-        )
-        w_out_pool = (
-            (w_out_conv2 - self.kernel_size_pool) // self.samples_100ms + 1
-            # After pooling layer
-        )
-        num_features = self.n_spatial_filters_smooth * w_out_pool
+        num_features = self._calc_num_features()
 
         # Layers
         self.ensure_dim = Rearrange("batch nchan times -> batch 1 nchan times")
@@ -118,23 +111,27 @@ def __init__(
 
         self.spatial_conv = nn.Conv2d(
             in_channels=1,
-            out_channels=self.n_filters_spat,
+            out_channels=self.n_spatial_filters,
             kernel_size=(self.n_chans, 1),
         )
 
+        self.spatial_batch_norm = nn.BatchNorm2d(
+            self.n_spatial_filters, momentum=batch_norm_momentum
+        )
+
         self.permute = Rearrange(
             "batch filspat nchans time -> batch nchans filspat time"
         )
 
         self.spatial_filt_conv = nn.Conv2d(
             in_channels=1,
             out_channels=self.n_spatial_filters_smooth,
-            kernel_size=(self.n_filters_spat, self.samples_100ms),
-            padding=0,
+            kernel_size=(self.n_spatial_filters, self.samples_100ms),
             bias=False,
         )
-        # Momentum following keras
-        self.batch_norm = nn.BatchNorm2d(self.n_spatial_filters_smooth, momentum=0.9)
+        self.batch_norm = nn.BatchNorm2d(
+            self.n_spatial_filters_smooth, momentum=batch_norm_momentum
+        )
 
         self.dropout = nn.Dropout(self.drop_prob)
         self.temporal_smoothing = nn.AvgPool2d(
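
The momentum change is more than a rename: PyTorch's BatchNorm2d momentum weights the new batch statistics (running = (1 - momentum) * running + momentum * batch), while Keras weights the old running statistics, so the Keras value 0.9 that the deleted comment refers to corresponds to 0.1 in PyTorch. The old momentum=0.9 therefore made the running statistics track each batch almost completely. A minimal sketch of the new knob; all other arguments are illustrative:

from braindecode.models import SCCNet

# Illustrative configuration; only batch_norm_momentum is new in this commit.
model = SCCNet(
    n_chans=22,
    n_outputs=4,
    n_times=1000,
    sfreq=250,
    batch_norm_momentum=0.1,  # PyTorch convention, equivalent to 0.9 in Keras
)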
@@ -150,6 +147,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shape: (batch_size, 1, n_chans, n_times)
         x = self.spatial_conv(x)
         # Shape: (batch_size, n_filters, 1, n_times)
+        x = self.spatial_batch_norm(x)
+        # Shape: (batch_size, n_filters, 1, n_times)
         x = self.permute(x)
         # Shape: (batch_size, 1, n_filters, n_times)
         x = self.spatial_filt_conv(x)
@@ -169,3 +168,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.final_layer(x)
         # Shape: (batch_size, n_outputs)
         return x
+
+    def _calc_num_features(self) -> int:
+        # Compute the number of features for the final linear layer
+        w_out_conv2 = (
+            self.n_times - self.samples_100ms + 1  # After second conv layer
+        )
+        w_out_pool = (
+            (w_out_conv2 - self.kernel_size_pool) // self.samples_100ms + 1
+            # After pooling layer
+        )
+        num_features = self.n_spatial_filters_smooth * w_out_pool
+        return num_features
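
To make the extracted helper concrete, here is the feature count traced for an assumed 250 Hz, 4-second window with the default n_spatial_filters_smooth=20; the input values are illustrative, not taken from the commit:

import math

# Assumed inputs: sfreq=250, n_times=1000, n_spatial_filters_smooth=20.
sfreq, n_times, n_smooth = 250, 1000, 20

samples_100ms = int(math.floor(sfreq * 0.1))  # 25 samples
kernel_size_pool = int(sfreq * 0.5)  # 125 samples, i.e. 0.5 s

w_out_conv2 = n_times - samples_100ms + 1  # 1000 - 25 + 1 = 976
w_out_pool = (w_out_conv2 - kernel_size_pool) // samples_100ms + 1  # 851 // 25 + 1 = 35
num_features = n_smooth * w_out_pool  # 20 * 35 = 700 inputs to the final linear layer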

docs/whats_new.rst

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ Enhancements
 
 Bugs
 ~~~~
+- Making the :class:`braindecode.models.SCCNet` more compatible with paper instead of source code (:gh:`742` by `Bruno Aristimunha`_)
 - Making the :class:`braindecode.models.EEGNeX` and :class:`braindecode.models.CTNet` more compatible with paper instead of source code (:gh:`740` by `Bruno Aristimunha`_)
 - Exposing extra variable to avoid problem with the parallel process (:gh:`736` by `Pierre Guetschel`_)
 - Fixing the IFNet (:gh:`739` by `Bruno Aristimunha`_)

test/unit_tests/models/test_models.py

Lines changed: 29 additions & 0 deletions
@@ -32,6 +32,7 @@
     EEGNetv1,
     EEGNetv4,
     EEGResNet,
+    EEGNeX,
     EEGSimpleConv,
     EEGTCNet,
     FBCNet,
@@ -1445,3 +1446,31 @@ def test_initialize_weights_conv():
     assert conv.weight.std().item() <= 0.02  # Checking trunc_normal_ std
     if conv.bias is not None:
         assert torch.allclose(conv.bias, torch.zeros_like(conv.bias))
+
+
+test_cases = [
+    pytest.param(64, id="n_times=64_perfect_multiple"),
+    pytest.param(437, id="n_times=437_trace_example"),  # Expect 104
+    pytest.param(95, id="n_times=95_edge_case_1"),  # Expect 24
+    pytest.param(67, id="n_times=67_edge_case_2"),  # Expect 16
+    pytest.param(94, id="n_times=94_edge_case_3"),  # Expect 24
+]
+
+@pytest.mark.parametrize("n_times_input", test_cases)
+def test_eegnex_final_layer_in_features(n_times_input):
+    """
+    Tests if the EEGNeX model correctly calculates the 'in_features'
+    for its final linear layer during initialization, especially for
+    n_times values that are not perfect multiples of pooling factors,
+    considering the specified padding.
+    """
+    n_chans_test = 2
+    n_outputs_test = 5
+
+    model = EEGNeX(
+        n_chans=n_chans_test,
+        n_outputs=n_outputs_test,
+        n_times=n_times_input
+    )
+
+    print(model)
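
The new test only instantiates the model, so the "# Expect ..." values in test_cases are documented but never asserted. A sketch of a stricter variant, assuming the computed width remains exposed as model.in_features (it is set as an attribute in __init__ in the eegnex.py diff above); the expected value for n_times=64 is derived from the pooling formula rather than from a comment in the test:

expected_in_features = {64: 16, 437: 104, 95: 24, 67: 16, 94: 24}

@pytest.mark.parametrize("n_times_input", sorted(expected_in_features))
def test_eegnex_in_features_value(n_times_input):
    model = EEGNeX(n_chans=2, n_outputs=5, n_times=n_times_input)
    assert model.in_features == expected_in_features[n_times_input]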
