Skip to content

Commit 31e54e6

Browse files
committed
Update linting to py310 target version
1 parent a5e806a commit 31e54e6

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+468
-485
lines changed

pytorch_forecasting/_registry/_lookup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def all_objects(
160160
ROOT = str(Path(__file__).parent.parent) # package root directory
161161

162162
def _coerce_to_str(obj):
163-
if isinstance(obj, (list, tuple)):
163+
if isinstance(obj, list | tuple):
164164
return [_coerce_to_str(o) for o in obj]
165165
if isclass(obj):
166166
obj = obj.get_tag("object_type")

pytorch_forecasting/base/_base_pkg.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ class Base_pkg(_BasePtForecasterV2):
4242

4343
def __init__(
4444
self,
45-
model_cfg: Optional[Union[dict[str, Any], str, Path]] = None,
46-
trainer_cfg: Optional[Union[dict[str, Any], str, Path]] = None,
47-
datamodule_cfg: Optional[Union[dict[str, Any], str, Path]] = None,
48-
ckpt_path: Optional[Union[str, Path]] = None,
45+
model_cfg: dict[str, Any] | str | Path | None = None,
46+
trainer_cfg: dict[str, Any] | str | Path | None = None,
47+
datamodule_cfg: dict[str, Any] | str | Path | None = None,
48+
ckpt_path: str | Path | None = None,
4949
):
5050
self.ckpt_path = Path(ckpt_path) if ckpt_path else None
5151
self.model_cfg = self._load_config(
@@ -74,9 +74,9 @@ def __init__(
7474

7575
@staticmethod
7676
def _load_config(
77-
config: Union[dict, str, Path, None],
78-
ckpt_path: Optional[Union[str, Path]] = None,
79-
auto_file_name: Optional[str] = None,
77+
config: dict | str | Path | None,
78+
ckpt_path: str | Path | None = None,
79+
auto_file_name: str | None = None,
8080
) -> dict:
8181
"""
8282
Loads configuration from a dictionary, YAML file, or Pickle file.
@@ -157,7 +157,7 @@ def _build_datamodule(self, data: TimeSeries) -> LightningDataModule:
157157
return datamodule_cls(data, **self.datamodule_cfg)
158158

159159
def _load_dataloader(
160-
self, data: Union[TimeSeries, LightningDataModule, DataLoader]
160+
self, data: TimeSeries | LightningDataModule | DataLoader
161161
) -> DataLoader:
162162
"""Converts various data input types into a DataLoader for prediction."""
163163
if isinstance(data, TimeSeries): # D1 Layer
@@ -191,11 +191,11 @@ def _save_artifact(self, output_dir: Path):
191191

192192
def fit(
193193
self,
194-
data: Union[TimeSeries, LightningDataModule],
194+
data: TimeSeries | LightningDataModule,
195195
# todo: we should create a base data_module for different data_modules
196196
save_ckpt: bool = True,
197-
ckpt_dir: Union[str, Path] = "checkpoints",
198-
ckpt_kwargs: Optional[dict[str, Any]] = None,
197+
ckpt_dir: str | Path = "checkpoints",
198+
ckpt_kwargs: dict[str, Any] | None = None,
199199
**trainer_fit_kwargs,
200200
):
201201
"""
@@ -265,10 +265,10 @@ def fit(
265265

266266
def predict(
267267
self,
268-
data: Union[TimeSeries, LightningDataModule, DataLoader],
269-
output_dir: Optional[Union[str, Path]] = None,
268+
data: TimeSeries | LightningDataModule | DataLoader,
269+
output_dir: str | Path | None = None,
270270
**kwargs,
271-
) -> Union[dict[str, torch.Tensor], None]:
271+
) -> dict[str, torch.Tensor] | None:
272272
"""
273273
Generate predictions by wrapping the model's predict method.
274274

pytorch_forecasting/callbacks/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class PredictCallback(BasePredictionWriter):
2929
def __init__(
3030
self,
3131
mode: str = "prediction",
32-
return_info: Optional[list[str]] = None,
32+
return_info: list[str] | None = None,
3333
mode_kwargs: dict[str, Any] = None,
3434
):
3535
super().__init__(write_interval="epoch")

pytorch_forecasting/data/_tslib_data_module.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
2121
from pytorch_forecasting.utils._coerce import _coerce_to_dict
2222

23-
NORMALIZER = Union[TorchNormalizer, EncoderNormalizer, NaNLabelEncoder]
23+
NORMALIZER = TorchNormalizer | EncoderNormalizer | NaNLabelEncoder
2424

2525

2626
class _TslibDataset(Dataset):
@@ -294,21 +294,21 @@ def __init__(
294294
freq: str = "h",
295295
add_relative_time_idx: bool = False,
296296
add_target_scales: bool = False,
297-
target_normalizer: Union[
298-
NORMALIZER, str, list[NORMALIZER], tuple[NORMALIZER], None
299-
] = "auto", # noqa: E501
300-
scalers: Optional[
301-
dict[
302-
str,
303-
Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer],
304-
]
305-
] = None, # noqa: E501
297+
target_normalizer: NORMALIZER
298+
| str
299+
| list[NORMALIZER]
300+
| tuple[NORMALIZER]
301+
| None = "auto", # noqa: E501
302+
scalers: dict[
303+
str, StandardScaler | RobustScaler | TorchNormalizer | EncoderNormalizer
304+
]
305+
| None = None, # noqa: E501
306306
shuffle: bool = True,
307307
window_stride: int = 1,
308308
batch_size: int = 32,
309309
num_workers: int = 0,
310310
train_val_test_split: tuple[float, float, float] = (0.7, 0.15, 0.15),
311-
collate_fn: Optional[callable] = None,
311+
collate_fn: callable | None = None,
312312
**kwargs,
313313
) -> None:
314314
super().__init__()
@@ -670,7 +670,7 @@ def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, in
670670

671671
return windows
672672

673-
def setup(self, stage: Optional[str] = None) -> None:
673+
def setup(self, stage: str | None = None) -> None:
674674
"""
675675
Setup the data module by preparing the datasets for training,
676676
testing and validation.
@@ -879,7 +879,7 @@ def collate_fn(batch):
879879
[x["static_continuous_features"] for x, _ in batch]
880880
)
881881

882-
if isinstance(batch[0][1], (list, tuple)):
882+
if isinstance(batch[0][1], list | tuple):
883883
num_targets = len(batch[0][1])
884884
y_batch = []
885885
for i in range(num_targets):

pytorch_forecasting/data/data_module.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pytorch_forecasting.data.timeseries import TimeSeries
2424
from pytorch_forecasting.utils._coerce import _coerce_to_dict
2525

26-
NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]
26+
NORMALIZER = TorchNormalizer | EncoderNormalizer | NaNLabelEncoder
2727

2828

2929
class EncoderDecoderTimeSeriesDataModule(LightningDataModule):
@@ -85,25 +85,25 @@ def __init__(
8585
self,
8686
time_series_dataset: TimeSeries,
8787
max_encoder_length: int = 30,
88-
min_encoder_length: Optional[int] = None,
88+
min_encoder_length: int | None = None,
8989
max_prediction_length: int = 1,
90-
min_prediction_length: Optional[int] = None,
91-
min_prediction_idx: Optional[int] = None,
90+
min_prediction_length: int | None = None,
91+
min_prediction_idx: int | None = None,
9292
allow_missing_timesteps: bool = False,
9393
add_relative_time_idx: bool = False,
9494
add_target_scales: bool = False,
95-
add_encoder_length: Union[bool, str] = "auto",
96-
target_normalizer: Union[
97-
NORMALIZER, str, list[NORMALIZER], tuple[NORMALIZER], None
98-
] = "auto",
99-
categorical_encoders: Optional[dict[str, NaNLabelEncoder]] = None,
100-
scalers: Optional[
101-
dict[
102-
str,
103-
Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer],
104-
]
105-
] = None,
106-
randomize_length: Union[None, tuple[float, float], bool] = False,
95+
add_encoder_length: bool | str = "auto",
96+
target_normalizer: NORMALIZER
97+
| str
98+
| list[NORMALIZER]
99+
| tuple[NORMALIZER]
100+
| None = "auto",
101+
categorical_encoders: dict[str, NaNLabelEncoder] | None = None,
102+
scalers: dict[
103+
str, StandardScaler | RobustScaler | TorchNormalizer | EncoderNormalizer
104+
]
105+
| None = None,
106+
randomize_length: None | tuple[float, float] | bool = False,
107107
batch_size: int = 32,
108108
num_workers: int = 0,
109109
train_val_test_split: tuple = (0.7, 0.15, 0.15),
@@ -623,7 +623,7 @@ def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, in
623623

624624
return windows
625625

626-
def setup(self, stage: Optional[str] = None):
626+
def setup(self, stage: str | None = None):
627627
"""Prepare the datasets for training, validation, testing, or prediction.
628628
629629
Parameters
@@ -746,7 +746,7 @@ def collate_fn(batch):
746746
[x["static_continuous_features"] for x, _ in batch]
747747
)
748748

749-
if isinstance(batch[0][1], (list, tuple)):
749+
if isinstance(batch[0][1], list | tuple):
750750
num_targets = len(batch[0][1])
751751
y_batch = []
752752
for i in range(num_targets):

0 commit comments

Comments (0)