Skip to content

Commit 687902d

Browse files
committed
added tests for theta forecaster
1 parent c1f1767 commit 687902d

File tree

4 files changed

+42
-13
lines changed

4 files changed

+42
-13
lines changed

ads/opctl/operator/lowcode/forecast/model/theta.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ads.opctl import logger
1818
from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig
1919
from ads.opctl.operator.lowcode.forecast.utils import (_label_encode_dataframe)
20+
from ads.opctl.operator.lowcode.common.utils import seconds_to_datetime
2021

2122
from ..const import (
2223
SupportedModels, ForecastOutputColumns, DEFAULT_TRIALS,
@@ -39,6 +40,8 @@ def freq_to_sp(freq: str) -> int | None:
3940
return 4
4041
if freq == "A" or freq == "Y": # Annual
4142
return 1 # Usually no seasonality
43+
if freq.startswith("W"): # Weekly data (W, W-SUN, W-MON, etc.)
44+
return 52
4245

4346
# Weekly data
4447
if freq == "D": # Daily
@@ -102,6 +105,8 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
102105
data_i = self.drop_horizon(data)
103106
target = self.spec.target_column
104107
freq = pd.infer_freq(data_i.index)
108+
if freq.startswith("W-"):
109+
freq = "W"
105110
data_i = data_i.asfreq(freq)
106111
y = data_i[target]
107112

@@ -244,15 +249,18 @@ def _generate_report(self):
244249
import report_creator as rc
245250
"""The method that needs to be implemented on the particular model level."""
246251
all_sections = []
252+
theta_blocks = []
253+
247254
for series_id, sm in self.models.items():
248255
model = sm["model"]
256+
249257
# ---- Extract details from ThetaModel ----
250258
fitted_params = model.get_fitted_params()
251259
alpha = fitted_params.get("initial_level", None)
260+
smoothing_level = fitted_params.get("smoothing_level", None)
252261
sp = model.sp
253262
deseasonalize_model = model.deseasonalize_model
254263
desasonalized = model.deseasonalize
255-
smoothing_level = fitted_params.get("smoothing_level", None)
256264
n_obs = len(model._y) if hasattr(model, "_y") else "N/A"
257265

258266
# Date range
@@ -263,7 +271,7 @@ def _generate_report(self):
263271
start_date = ""
264272
end_date = ""
265273

266-
# ---- Build the text block ----
274+
# ---- Build the DF ----
267275
meta_df = pd.DataFrame({
268276
"Metric": [
269277
"Alpha / Initial Level",
@@ -273,7 +281,7 @@ def _generate_report(self):
273281
"Deseasonalization Method",
274282
"Period (sp)",
275283
"Sample Start",
276-
"Sample End"
284+
"Sample End",
277285
],
278286
"Value": [
279287
alpha,
@@ -283,18 +291,31 @@ def _generate_report(self):
283291
deseasonalize_model,
284292
sp,
285293
start_date,
286-
end_date
287-
]
294+
end_date,
295+
],
288296
})
289297

290-
# ---- Add to Report Creator ----
291-
theta_section = rc.Block(
292-
rc.Heading(f"Theta Model Summary — {series_id}", level=2),
293-
rc.Text("This section provides detailed ThetaModel fit diagnostics."),
298+
# ---- Create a block (NOT a section directly) ----
299+
theta_block = rc.Block(
300+
rc.Heading(f"Theta Model Summary", level=3),
294301
rc.DataTable(meta_df),
302+
label=series_id
295303
)
296304

297-
all_sections.append(theta_section)
305+
# Add with optional label support
306+
theta_blocks.append(
307+
theta_block
308+
)
309+
310+
# ---- Combine into final section like ARIMA example ----
311+
theta_title = rc.Heading("Theta Model Parameters", level=2)
312+
313+
if len(theta_blocks) > 1:
314+
theta_section = rc.Select(blocks=theta_blocks)
315+
else:
316+
theta_section = theta_blocks[0]
317+
318+
all_sections.extend([theta_title, theta_section])
298319

299320
if self.spec.generate_explanations:
300321
try:
@@ -383,14 +404,19 @@ def get_explain_predict_fn(self, series_id):
383404
def _custom_predict(
384405
data,
385406
model=self.models[series_id]["model"],
407+
dt_column_name=self.datasets._datetime_column_name,
386408
target_col=self.original_target_column,
387409
):
388410
"""
389411
data: ForecastDatasets.get_data_at_series(s_id)
390412
"""
391413
data = data.drop([target_col], axis=1)
414+
data[dt_column_name] = seconds_to_datetime(
415+
data[dt_column_name], dt_format=self.spec.datetime_column.format
416+
)
392417
data = self.preprocess(data, series_id)
393-
fh = ForecastingHorizon(pd.to_datetime(data.index), is_relative=False)
418+
h = len(data)
419+
fh = ForecastingHorizon(np.arange(1, h + 1), is_relative=True)
394420
return model.predict(fh)
395421

396422
return _custom_predict

tests/operators/forecast/test_datasets.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"prophet",
3333
"neuralprophet",
3434
"autots",
35+
"theta",
3536
# "lgbforecast",
3637
"auto-select",
3738
"auto-select-series",
@@ -241,7 +242,7 @@ def test_pandas_to_historical(model):
241242
check_output_for_errors(output_data_path)
242243

243244

244-
@pytest.mark.parametrize("model", ["prophet", "neuralprophet"])
245+
@pytest.mark.parametrize("model", ["prophet", "neuralprophet", "theta"])
245246
def test_pandas_to_historical_test(model):
246247
df = pd.read_csv(f"{DATASET_PREFIX}dataset4.csv")
247248
df_train = df[:-PERIODS]
@@ -268,7 +269,7 @@ def test_pandas_to_historical_test(model):
268269

269270

270271
# CostAD
271-
@pytest.mark.parametrize("model", ["prophet", "neuralprophet"])
272+
@pytest.mark.parametrize("model", ["prophet", "neuralprophet", "theta"])
272273
def test_pandas_to_historical_test2(model):
273274
df = pd.read_csv(f"{DATASET_PREFIX}dataset5.csv")
274275
df_train = df[:-1]

tests/operators/forecast/test_errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@
142142
"automlx",
143143
"prophet",
144144
"neuralprophet",
145+
"theta",
145146
"autots",
146147
# "lgbforecast",
147148
]

tests/operators/forecast/test_explainers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# "automlx", # FIXME: automlx is failing, no errors
2121
"prophet",
2222
"neuralprophet",
23+
"theta",
2324
"auto-select-series",
2425
]
2526

0 commit comments

Comments
 (0)