@@ -164,10 +164,11 @@ def _regressor_heuristic(id: str, model_size: ModelSizeOrUnits, dim_input: int,
def _flat_context_heuristic(id: str, model_size: ModelSizeOrUnits, dim_input: int) -> list[int]:
    """Compute the layer dimensions used by the flat context heuristic.

    Args:
        id: Key used to look up explicit dimensions when ``model_size`` is a
            dict. (The name shadows the ``id`` builtin; it is kept so that
            keyword callers continue to work.)
        model_size: Either a dict of explicit dimensions keyed by ``id``, or
            one of the size presets ``"S"``, ``"M"`` or ``"L"``.
        dim_input: Input dimensionality that the preset widths are scaled by.

    Returns:
        ``model_size[id]`` unchanged when ``model_size`` is a dict; otherwise
        the preset base units multiplied by a log-scaled coefficient.

    Raises:
        KeyError: If ``id`` is missing from a dict ``model_size``, or if
            ``model_size`` is not one of the known presets.
    """
    # Explicit per-component units take precedence over the heuristic.
    if isinstance(model_size, dict):
        return model_size[id]

    # Base units per size preset.
    model_size_layers = dict(S=[4], M=[16], L=[64])
    layers = model_size_layers[model_size]

    # Clamping the input at e keeps the natural log >= 1, so the rounded
    # coefficient is never below 1 and the widths are never scaled to zero.
    coefficient = round(np.log(max(dim_input, np.e)))
    dims = [unit * coefficient for unit in layers]

    _LOG.info(f"[ARGN] flat context heuristic: {dim_input=} -> {dims}")
    return dims
172173
173174
@@ -176,10 +177,11 @@ def _sequential_context_heuristic(
176177) -> list [int ]:
177178 if isinstance (model_size , dict ):
178179 return model_size [id ]
179- model_size_layers = dict (S = [8 ], M = [32 ], L = [64 , 64 ])
180+ model_size_layers = dict (S = [4 ], M = [16 ], L = [64 , 64 ])
180181 layers = model_size_layers [model_size ]
181182 coefficient = round (np .log (max (dim_input * seq_len_median , np .e )))
182183 dims = [unit * coefficient for unit in layers ]
184+ _LOG .info (f"[ARGN] sequential context heuristic: { dim_input = } x { seq_len_median = } -> { dims } " )
183185 return dims
184186
185187
@@ -190,6 +192,7 @@ def _history_heuristic(id: str, model_size: ModelSizeOrUnits, dim_input: int, se
190192 layers = model_size_layers [model_size ]
191193 coefficient = round (np .log (max (dim_input * seq_len_median , np .e )))
192194 dims = [unit * coefficient for unit in layers ]
195+ _LOG .info (f"[ARGN] history heuristic: { dim_input = } x { seq_len_median = } -> { dims } " )
193196 return dims
194197
195198
@@ -425,6 +428,7 @@ def forward(self, x) -> list[torch.Tensor]:
425428 xs = torch .cat (list (embeddings .values ()), dim = - 1 )
426429 for compressor_layer in self .get ():
427430 xs = compressor_layer (xs )
431+ xs = torch .relu (xs )
428432 xs = self .dropout (xs )
429433 flat_context = [xs ]
430434 return flat_context
0 commit comments