@@ -66,11 +66,12 @@ def call(self, inputs, training=None, **kwargs):
6666
6767class TCN (layers .Layer ):
6868
69- def __init__ (self , filters , kernel_size , ** kwargs ):
69+ def __init__ (self , filters , kernel_size , return_sequence = False , ** kwargs ):
7070 super (TCN , self ).__init__ (** kwargs )
7171 self .blocks = []
7272 self .depth = len (filters )
7373 self .kernel_size = kernel_size
74+ self .return_sequence = return_sequence
7475
7576 for i in range (self .depth ):
7677 dilation_size = 2 ** i
@@ -81,24 +82,26 @@ def __init__(self, filters, kernel_size, **kwargs):
8182 name = f"residual_block_{ i } " )
8283 )
8384
84- self .slice_layer = layers .Lambda (lambda tt : tt [:, - 1 , :])
85+ if not self .return_sequence :
86+ self .slice_layer = layers .Lambda (lambda tt : tt [:, - 1 , :])
8587
8688 def call (self , inputs , training = None , ** kwargs ):
8789 x = inputs
8890 for block in self .blocks :
8991 x = block (x )
9092
91- x = self .slice_layer (x )
93+ if not self .return_sequence :
94+ x = self .slice_layer (x )
9295 return x
9396
9497 @property
9598 def receptive_field_size (self ):
9699 return 1 + 2 * (self .kernel_size - 1 ) * (2 ** self .depth - 1 )
97100
98101
99- def build_model (sequence_length , channels , filters , num_classes , kernel_size ):
102+ def build_model (sequence_length , channels , filters , num_classes , kernel_size , return_sequence = False ):
100103 inputs = Input (shape = (sequence_length , channels ), name = "inputs" )
101- tcn_block = TCN (filters , kernel_size )
104+ tcn_block = TCN (filters , kernel_size , return_sequence )
102105 x = tcn_block (inputs )
103106
104107 outputs = layers .Dense (num_classes ,
0 commit comments