|
187 | 187 | " max_prediction_length=prediction_length,\n", |
188 | 188 | ")\n", |
189 | 189 | "\n", |
190 | | - "validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training_cutoff + 1)\n", |
| 190 | + "validation = TimeSeriesDataSet.from_dataset(\n", |
| 191 | + " training, data, min_prediction_idx=training_cutoff + 1\n", |
| 192 | + ")\n", |
191 | 193 | "batch_size = 128\n", |
192 | | - "train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)\n", |
193 | | - "val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)" |
| 194 | + "train_dataloader = training.to_dataloader(\n", |
| 195 | + " train=True, batch_size=batch_size, num_workers=0\n", |
| 196 | + ")\n", |
| 197 | + "val_dataloader = validation.to_dataloader(\n", |
| 198 | + " train=False, batch_size=batch_size, num_workers=0\n", |
| 199 | + ")" |
194 | 200 | ] |
195 | 201 | }, |
196 | 202 | { |
|
251 | 257 | }, |
252 | 258 | { |
253 | 259 | "cell_type": "code", |
254 | | - "execution_count": 6, |
| 260 | + "execution_count": null, |
255 | 261 | "metadata": {}, |
256 | 262 | "outputs": [ |
257 | 263 | { |
|
269 | 275 | "source": [ |
270 | 276 | "pl.seed_everything(42)\n", |
271 | 277 | "trainer = pl.Trainer(accelerator=\"auto\", gradient_clip_val=0.1)\n", |
272 | | - "net = NBeats.from_dataset(training, learning_rate=3e-2, weight_decay=1e-2, widths=[32, 512], backcast_loss_ratio=0.1)" |
| 278 | + "net = NBeats.from_dataset(\n", |
| 279 | + " training,\n", |
| 280 | + " learning_rate=3e-2,\n", |
| 281 | + " weight_decay=1e-2,\n", |
| 282 | + " widths=[32, 512],\n", |
| 283 | + " backcast_loss_ratio=0.1,\n", |
| 284 | + ")" |
273 | 285 | ] |
274 | 286 | }, |
275 | 287 | { |
276 | 288 | "cell_type": "code", |
277 | | - "execution_count": 7, |
| 289 | + "execution_count": null, |
278 | 290 | "metadata": {}, |
279 | 291 | "outputs": [ |
280 | 292 | { |
|
321 | 333 | ], |
322 | 334 | "source": [ |
323 | 335 | "# find optimal learning rate\n", |
324 | | - "from lightning.pytorch.tuner import Tuner\n", |
| 336 | + "from pytorch_forecasting.tuning import Tuner\n", |
325 | 337 | "\n", |
326 | | - "res = Tuner(trainer).lr_find(net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)\n", |
| 338 | + "res = Tuner(trainer).lr_find(\n", |
| 339 | + " net,\n", |
| 340 | + " train_dataloaders=train_dataloader,\n", |
| 341 | + " val_dataloaders=val_dataloader,\n", |
| 342 | + " min_lr=1e-5,\n", |
| 343 | + ")\n", |
327 | 344 | "print(f\"suggested learning rate: {res.suggestion()}\")\n", |
328 | 345 | "fig = res.plot(show=True, suggest=True)\n", |
329 | 346 | "fig.show()\n", |
|
340 | 357 | }, |
341 | 358 | { |
342 | 359 | "cell_type": "code", |
343 | | - "execution_count": 14, |
| 360 | + "execution_count": null, |
344 | 361 | "metadata": {}, |
345 | 362 | "outputs": [ |
346 | 363 | { |
|
443 | 460 | } |
444 | 461 | ], |
445 | 462 | "source": [ |
446 | | - "early_stop_callback = EarlyStopping(monitor=\"val_loss\", min_delta=1e-4, patience=10, verbose=False, mode=\"min\")\n", |
| 463 | + "early_stop_callback = EarlyStopping(\n", |
| 464 | + " monitor=\"val_loss\", min_delta=1e-4, patience=10, verbose=False, mode=\"min\"\n", |
| 465 | + ")\n", |
447 | 466 | "trainer = pl.Trainer(\n", |
448 | 467 | " max_epochs=3,\n", |
449 | 468 | " accelerator=\"auto\",\n", |
|
481 | 500 | }, |
482 | 501 | { |
483 | 502 | "cell_type": "code", |
484 | | - "execution_count": 15, |
| 503 | + "execution_count": null, |
485 | 504 | "metadata": {}, |
486 | 505 | "outputs": [], |
487 | 506 | "source": [ |
|
645 | 664 | ], |
646 | 665 | "source": [ |
647 | 666 | "for idx in range(10): # plot 10 examples\n", |
648 | | - " best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)" |
| 667 | + " best_model.plot_prediction(\n", |
| 668 | + " raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True\n", |
| 669 | + " )" |
649 | 670 | ] |
650 | 671 | }, |
651 | 672 | { |
|
0 commit comments