Skip to content

Commit

Permalink
update skim imprementation
Browse files Browse the repository at this point in the history
  • Loading branch information
YoshikiMas committed Mar 18, 2023
1 parent 6a9e504 commit c11b9ba
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 30 deletions.
2 changes: 2 additions & 0 deletions asteroid/masknn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .convolutional import TDConvNet, TDConvNetpp, SuDORMRF, SuDORMRFImproved
from .recurrent import DPRNN, LSTMMasker
from .attention import DPTransformer
from .skim import SkiM

__all__ = [
"TDConvNet",
"DPRNN",
"SkiM"
"DPTransformer",
"LSTMMasker",
"SuDORMRF",
Expand Down
47 changes: 17 additions & 30 deletions asteroid/masknn/skim.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,22 +244,15 @@ def forward(self, x):
assert H == self.seg_input_size

# Construct chunks
if self.chunk_size == self.hop_size:
x, rest = self._padfeature(x)
x = x.permute(0, 2, 1).contiguous()
x = x.view(B, -1, self.chunk_size, H) # [B, S (number of chunks), chunk_size, H]
S = x.shape[1] # number of chunks

else:
x = unfold(
x.unsqueeze(-1),
kernel_size=(self.chunk_size, 1),
padding=(self.chunk_size, 0),
stride=(self.hop_size, 1),
)
S = x.shape[-1] # number of chunks
x = x.reshape(B, H, self.chunk_size, S)
x = x.permute(0, 3, 2, 1).contiguous() # [B, S (number of chunks), chunk_size, H]
x = unfold(
x.unsqueeze(-1),
kernel_size=(self.chunk_size, 1),
padding=(self.chunk_size, 0),
stride=(self.hop_size, 1),
)
S = x.shape[-1] # number of chunks
x = x.reshape(B, H, self.chunk_size, S)
x = x.permute(0, 3, 2, 1).contiguous() # [B, S (number of chunks), chunk_size, H]

# Main SkiM processing
x = x.view(B*S, self.chunk_size, H).contiguous()
Expand All @@ -273,21 +266,15 @@ def forward(self, x):

# Reconstruct from chunks
x = x.permute(0, 3, 2, 1).contiguous() # [B, H, chunk_size, S (number of chunks)]
if self.chunk_size == self.hop_size:
x = x.view(B, H, self.chunk_size*S)
x = x[..., :T]

else:
x = fold(
x.reshape(B, self.chunk_size*H, S),
(T, 1),
kernel_size=(self.chunk_size, 1),
padding=(self.chunk_size, 0),
stride=(self.hop_size, 1),
)
x = x[..., 0]
x = fold(
x.reshape(B, self.chunk_size*H, S),
(T, 1),
kernel_size=(self.chunk_size, 1),
padding=(self.chunk_size, 0),
stride=(self.hop_size, 1),
)

return x
return x[..., 0]

def _padfeature(self, x):
B, H, T = x.size()
Expand Down
2 changes: 2 additions & 0 deletions asteroid/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .lstm_tasnet import LSTMTasNet
from .demask import DeMask
from .x_umx import XUMX
from .skim_tasnet import SkiMTasNet

# Sharing-related
from .publisher import save_publishable, upload_publishable
Expand All @@ -26,6 +27,7 @@
"DCUNet",
"DCCRNet",
"XUMX",
"SkiMTasNet",
"save_publishable",
"upload_publishable",
]
Expand Down

0 comments on commit c11b9ba

Please sign in to comment.