Commit fd6d7b0

perf: refine context heuristics

1 parent f1a884f
File tree: 1 file changed, +6 −2 lines changed

1 file changed

+6
-2
lines changed

mostlyai/engine/_tabular/argn.py

Lines changed: 6 additions & 2 deletions
@@ -164,10 +164,11 @@ def _regressor_heuristic(id: str, model_size: ModelSizeOrUnits, dim_input: int,
 def _flat_context_heuristic(id: str, model_size: ModelSizeOrUnits, dim_input: int) -> list[int]:
     if isinstance(model_size, dict):
         return model_size[id]
-    model_size_layers = dict(S=[8], M=[64], L=[128])
+    model_size_layers = dict(S=[4], M=[16], L=[64])
     layers = model_size_layers[model_size]
     coefficient = round(np.log(max(dim_input, np.e)))
     dims = [unit * coefficient for unit in layers]
+    _LOG.info(f"[ARGN] flat context heuristic: {dim_input=} -> {dims}")
     return dims
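As a rough illustration of what the new constants do, here is a standalone sketch of the flat-context sizing rule; the helper name flat_context_dims and the dim_input value of 200 are invented for the example and are not part of argn.py.

import numpy as np

def flat_context_dims(dim_input: int, model_size: str = "M") -> list[int]:
    # mirrors the non-dict branch of _flat_context_heuristic with the new constants
    model_size_layers = dict(S=[4], M=[16], L=[64])
    layers = model_size_layers[model_size]
    # log-scaled coefficient; max(dim_input, e) keeps it at 1 or more
    coefficient = round(np.log(max(dim_input, np.e)))
    return [unit * coefficient for unit in layers]

# dim_input=200 (hypothetical): ln(200) ≈ 5.3 -> coefficient 5,
# so size M shrinks from [64 * 5] = [320] under the old constants to [16 * 5] = [80]
dims = flat_context_dims(200, "M")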

@@ -176,10 +177,11 @@ def _sequential_context_heuristic(
 ) -> list[int]:
     if isinstance(model_size, dict):
         return model_size[id]
-    model_size_layers = dict(S=[8], M=[32], L=[64, 64])
+    model_size_layers = dict(S=[4], M=[16], L=[64, 64])
     layers = model_size_layers[model_size]
     coefficient = round(np.log(max(dim_input * seq_len_median, np.e)))
     dims = [unit * coefficient for unit in layers]
+    _LOG.info(f"[ARGN] sequential context heuristic: {dim_input=} x {seq_len_median=} -> {dims}")
     return dims

@@ -190,6 +192,7 @@ def _history_heuristic(id: str, model_size: ModelSizeOrUnits, dim_input: int, se
     layers = model_size_layers[model_size]
     coefficient = round(np.log(max(dim_input * seq_len_median, np.e)))
     dims = [unit * coefficient for unit in layers]
+    _LOG.info(f"[ARGN] history heuristic: {dim_input=} x {seq_len_median=} -> {dims}")
     return dims
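The sequential-context and history heuristics share the same sizing rule, scaled by dim_input * seq_len_median; only the sequential constants change in this commit, while the history hunk just gains the log line. A hedged sketch of that shared rule with invented inputs (dim_input=40, seq_len_median=12):

import numpy as np

def seq_scaled_dims(dim_input: int, seq_len_median: int, layers: list[int]) -> list[int]:
    # shared rule used by the sequential-context and history heuristics
    coefficient = round(np.log(max(dim_input * seq_len_median, np.e)))
    return [unit * coefficient for unit in layers]

# dim_input=40, seq_len_median=12 -> ln(480) ≈ 6.2 -> coefficient 6;
# sequential size M shrinks from [32 * 6] = [192] to [16 * 6] = [96]
seq_dims = seq_scaled_dims(40, 12, [16])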

@@ -425,6 +428,7 @@ def forward(self, x) -> list[torch.Tensor]:
         xs = torch.cat(list(embeddings.values()), dim=-1)
         for compressor_layer in self.get():
             xs = compressor_layer(xs)
+            xs = torch.relu(xs)
         xs = self.dropout(xs)
         flat_context = [xs]
         return flat_context
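The forward() change interleaves a ReLU after each compressor layer instead of stacking purely linear compressors. A minimal self-contained sketch of that pattern, assuming the compressor stack is a list of nn.Linear modules and an invented dropout probability; the actual module in argn.py may be structured differently:

import torch
import torch.nn as nn

class FlatContextSketch(nn.Module):
    def __init__(self, dims: list[int]) -> None:
        super().__init__()
        # one linear compressor per consecutive pair of dims
        self.compressors = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.dropout = nn.Dropout(p=0.25)  # invented probability

    def forward(self, xs: torch.Tensor) -> list[torch.Tensor]:
        for compressor_layer in self.compressors:
            xs = compressor_layer(xs)
            xs = torch.relu(xs)  # the non-linearity added by this commit
        xs = self.dropout(xs)
        return [xs]

# e.g. a batch of 3 concatenated context embeddings of width 320 compressed to 80
flat_context = FlatContextSketch([320, 80])(torch.randn(3, 320))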
