Merged

Changes from all commits (39 commits)
6765337
isort lint
duyifanict Nov 11, 2025
61ad2bc
chore: 🤖 fix isort.cfg & add release branch to test
duyifanict Nov 12, 2025
57b7a4e
fix: 🐛 fix a bug in Autoformer
yisongfu Nov 13, 2025
ca15366
An attempt at koopa
Nov 17, 2025
0bd67c1
An attempt at koopa
Nov 17, 2025
64549ac
Tweak formatting
Nov 17, 2025
2449e8a
feat: 🎸 upgrade FiLM and TiDE
ChengqingYu Nov 21, 2025
c117a00
Merge pull request #293 from wgawmy/beta
yisongfu Dec 2, 2025
8f052ca
tests: 📏 fix smoke tests for koopa
duyifanict Dec 2, 2025
fbe7d62
fix isort for koopa
duyifanict Dec 2, 2025
59d48b9
tests: 📏 add new smoke tests
duyifanict Dec 5, 2025
eac4c97
Merge pull request #301 from duyifanict/beta
yisongfu Dec 5, 2025
02bcf48
Merge remote-tracking branch 'origin/beta' into beta
ChengqingYu Dec 5, 2025
a8c7348
fix: 🐛 fix a bug in BLAST
ChengqingYu Dec 5, 2025
4581fd7
fix: 🐛 fix a bug in Autoformer
ChengqingYu Dec 5, 2025
9215117
fix: 🐛 fix a bug in SOFTS
ChengqingYu Dec 5, 2025
741cafa
fix: 🐛 fix a bug in TimeXer
ChengqingYu Dec 5, 2025
6d3e78b
fix: 🐛 fix a bug in NSformer
ChengqingYu Dec 5, 2025
29ab595
fix: 🐛 fix a bug in PatchTST
ChengqingYu Dec 5, 2025
35d1c9b
fix: 🐛 fix a bug in TimeMixer
ChengqingYu Dec 5, 2025
0c18eb0
fix: 🐛 fix bugs in models
ChengqingYu Dec 5, 2025
fd0e65f
lint
duyifanict Dec 8, 2025
f259698
chore: 🤖 add action for auto release and pypi-update (#304)
duyifanict Dec 15, 2025
ff4113a
chore: 🤖 move koopa callback to model directory
yisongfu Dec 18, 2025
b7a2e5b
update selective learning
yisongfu Dec 18, 2025
870c585
fix: 🐛 fix a bug that config may be mistakenly overwritten by shortcu…
yisongfu Dec 18, 2025
53e5135
update base config
yisongfu Dec 18, 2025
ea5b170
update to version 1.1
yisongfu Dec 18, 2025
55d5946
fix lint
yisongfu Dec 19, 2025
76b4891
fix: 🐛 fix a bug in DDP training
yisongfu Dec 19, 2025
2cf1ebf
update selective learning to support load configuration from json
yisongfu Dec 19, 2025
e02cb12
build: add distribution files
yisongfu Dec 19, 2025
c2edee4
fix smoke test for koopa
duyifanict Dec 19, 2025
a360ac1
Merge branch 'beta' into release/v1.1.0
duyifanict Dec 19, 2025
f7f8f15
build: add distribution files
duyifanict Dec 19, 2025
260635f
Enable write permissions for GITHUB_TOKEN
duyifanict Dec 19, 2025
ab35018
build: add distribution files
duyifanict Dec 19, 2025
9795acc
delete dist
duyifanict Dec 23, 2025
b6506e2
Merge branch 'beta'
duyifanict Dec 23, 2025
5 changes: 3 additions & 2 deletions .github/workflows/python-test.yml
@@ -5,9 +5,9 @@ name: Pytest
 
 on:
   push:
-    branches: [ "master" ]
+    branches: [ "master", "release/*" ]
   pull_request:
-    branches: [ "master" ]
+    branches: [ "master", "release/*" ]
 
 permissions:
   contents: read
@@ -29,6 +29,7 @@ jobs:
         pip install flake8 pytest
         pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+        if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi
     - name: Test with pytest
       run: |
         pytest
4 changes: 2 additions & 2 deletions .isort.cfg
@@ -1,3 +1,3 @@
 [settings]
-src_paths=basicts,tests
-skip_glob=baselines/*,assets/*,examples/*
+src_paths=src/basicts,tests
+skip_glob=baselines/*,assets/*
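Two things follow from this hunk: `src_paths=src/basicts` tracks the repository's move to a `src/` layout so isort still treats `basicts` as first-party, and removing `examples/*` from `skip_glob` is why the demo scripts below get re-sorted. The wrapped imports seen throughout this PR are isort's grid-style wrapping, roughly:

```python
# before: over-length one-line import, flagged once examples/ is linted
from basicts.models.iTransformer import iTransformerForForecasting, iTransformerConfig

# after: isort alphabetizes the names and wraps them in parentheses
from basicts.models.iTransformer import (iTransformerConfig,
                                         iTransformerForForecasting)
```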
2 changes: 1 addition & 1 deletion .pylintrc
@@ -12,7 +12,7 @@ ignore=baselines,assets,checkpoints,examples,scripts
 
 # Files or directories matching the regex patterns are skipped. The regex
 # matches against base names, not paths.
-ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE
+ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE|^.*\.toml
 
 # Pickle collected data for later comparisons.
 persistent=no
3 changes: 2 additions & 1 deletion examples/classification/classification_demo.py
@@ -1,6 +1,7 @@
 from basicts import BasicTSLauncher
 from basicts.configs import BasicTSClassificationConfig
-from basicts.models.iTransformer import iTransformerForClassification, iTransformerConfig
+from basicts.models.iTransformer import (iTransformerConfig,
+                                         iTransformerForClassification)
 
 
 def main():
8 changes: 5 additions & 3 deletions examples/forecasting/forecasting_demo.py
@@ -1,9 +1,11 @@
+from torch.optim.lr_scheduler import MultiStepLR
+
 from basicts import BasicTSLauncher
 from basicts.configs import BasicTSForecastingConfig
-from basicts.models.iTransformer import iTransformerForForecasting, iTransformerConfig
-from basicts.runners.callback import EarlyStopping, GradientClipping
 from basicts.metrics import masked_mse
-from torch.optim.lr_scheduler import MultiStepLR
+from basicts.models.iTransformer import (iTransformerConfig,
+                                         iTransformerForForecasting)
+from basicts.runners.callback import EarlyStopping, GradientClipping
 
 
 def main():
3 changes: 2 additions & 1 deletion examples/imputation/imputation_demo.py
@@ -1,6 +1,7 @@
 from basicts import BasicTSLauncher
 from basicts.configs import BasicTSImputationConfig
-from basicts.models.iTransformer import iTransformerForReconstruction, iTransformerConfig
+from basicts.models.iTransformer import (iTransformerConfig,
+                                         iTransformerForReconstruction)
 
 
 def main():
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
     "sympy",
     "openpyxl",
     "setuptools==59.5.0",
-    "numpy==1.24.4",
+    "numpy",
     "tqdm==4.67.1",
     "tensorboard==2.18.0",
     "transformers==4.40.1"
2 changes: 1 addition & 1 deletion src/basicts/__init__.py
@@ -1,5 +1,5 @@
 from .launcher import BasicTSLauncher
 
-__version__ = '1.0.2'
+__version__ = '1.1.0'
 
 __all__ = ['__version__', 'BasicTSLauncher']
19 changes: 12 additions & 7 deletions src/basicts/configs/base_config.py
@@ -9,17 +9,22 @@
 from functools import partial
 from numbers import Number
 from types import FunctionType
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Callable, List, Literal, Optional, Tuple,
+                    Union)
 
 import numpy as np
 import torch
-from basicts.runners.callback import BasicTSCallback
-from basicts.runners.taskflow import BasicTSTaskFlow
 from easydict import EasyDict
 from torch.optim.lr_scheduler import LRScheduler
 
 from .model_config import BasicTSModelConfig
 
+# avoid circular imports
+if TYPE_CHECKING:
+    from basicts.runners.callback import BasicTSCallback
+    from basicts.runners.taskflow import BasicTSTaskFlow
+
+
 
 @dataclass(init=False)
 class BasicTSConfig(EasyDict):
@@ -35,8 +40,8 @@ class BasicTSConfig(EasyDict):
     model_config: BasicTSModelConfig
 
     dataset_name: str
-    taskflow: BasicTSTaskFlow
-    callbacks: List[BasicTSCallback]
+    taskflow: "BasicTSTaskFlow"
+    callbacks: List["BasicTSCallback"]
 
     ############################## General Configuration ##############################
@@ -277,7 +282,7 @@ def _pack_params(self, obj: type, obj_params: Union[dict, None]) -> dict:
             elif issubclass(obj, LRScheduler) and k == "optimizer":
                 continue
             # short cut has higher priority than params in config
-            elif k in self:
+            elif k in self and self[k] is not None:
                 obj_params[k] = self[k]
         return obj_params

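The added `self[k] is not None` guard matters because the shortcut fields elsewhere in this PR (`lr`, `input_len`, `memmap`, ...) now default to `None`: without it, an unset shortcut would clobber an explicit value inside `optimizer_params` or `dataset_params`. A toy illustration of the difference, with a plain dict standing in for the config:

```python
cfg = {"lr": None, "optimizer_params": {"lr": 2e-4, "weight_decay": 5e-4}}
obj_params = dict(cfg["optimizer_params"])

for k in list(obj_params):
    # old check (`k in cfg`) would copy the unset shortcut, giving lr=None;
    # the new check lets the shortcut win only when the user actually set it
    if k in cfg and cfg[k] is not None:
        obj_params[k] = cfg[k]

print(obj_params)  # {'lr': 0.0002, 'weight_decay': 0.0005}
```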
@@ -338,7 +343,7 @@ def _serialize_obj(self, obj: object) -> object:
             if not isinstance(is_default, bool):
                 raise ValueError(f"Parameter {k} of {obj.__class__.__name__} is not serializable.")
             if not is_default:
-                params[k] = repr(v)
+                params[k] = self._serialize_obj(v)
 
         return {
             "name": obj.__class__.__name__,
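Recursing via `self._serialize_obj(v)` instead of `repr(v)` means a non-default parameter that is itself an object (say, a callback holding a scheduler) serializes into the same structured `{"name": ..., "params": ...}` form rather than collapsing into an opaque string. A simplified sketch of the idea, not the actual BasicTS method:

```python
def serialize(obj):
    """Recursively reduce an object tree to JSON-friendly data."""
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    if isinstance(obj, (list, tuple)):
        return [serialize(v) for v in obj]
    if isinstance(obj, dict):
        return {k: serialize(v) for k, v in obj.items()}
    # structured record instead of repr(obj), so nesting survives a round-trip
    return {"name": type(obj).__name__,
            "params": {k: serialize(v) for k, v in vars(obj).items()}}
```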
13 changes: 8 additions & 5 deletions src/basicts/configs/tsc_config.py
@@ -2,12 +2,13 @@
 from typing import Callable, List, Literal, Tuple, Union
 
 import numpy as np
+from torch.nn import CrossEntropyLoss
+from torch.optim import Adam
+
 from basicts.data import UEADataset
 from basicts.runners.callback import BasicTSCallback
 from basicts.runners.taskflow import (BasicTSClassificationTaskFlow,
                                       BasicTSTaskFlow)
-from torch.nn import CrossEntropyLoss
-from torch.optim import Adam
 
 from .base_config import BasicTSConfig
 from .model_config import BasicTSModelConfig
@@ -99,9 +100,11 @@ class BasicTSClassificationConfig(BasicTSConfig):
 
     # Dataset settings
     dataset_type: type = field(default=UEADataset, metadata={"help": "Dataset type."})
-    dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."})
+    dataset_params: Union[dict, None] = field(
+        default_factory=lambda: {"memmap": False},
+        metadata={"help": "Dataset parameters."})
     use_timestamps: bool = field(default=False, metadata={"help": "Whether to use timestamps as supplementary."})
-    memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."})
+    memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."})
     null_val: float = field(default=np.nan, metadata={"help": "Null value."})
     null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."})
 
@@ -148,7 +151,7 @@ class BasicTSClassificationConfig(BasicTSConfig):
     optimizer_params: dict = field(
         default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4},
         metadata={"help": "Optimizer parameters."})
-    lr: float = field(default=2e-4, metadata={"help": "Learning rate."})
+    lr: float = field(default=None, metadata={"help": "Learning rate."})
 
     # Learning rate scheduler
     lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."})
30 changes: 20 additions & 10 deletions src/basicts/configs/tsf_config.py
@@ -2,12 +2,13 @@
 from typing import Callable, List, Literal, Tuple, Union
 
 import numpy as np
+from torch.optim import Adam
+
 from basicts.data import BasicTSForecastingDataset
 from basicts.runners.callback import BasicTSCallback
 from basicts.runners.taskflow import (BasicTSForecastingTaskFlow,
                                       BasicTSTaskFlow)
 from basicts.scaler import ZScoreScaler
-from torch.optim import Adam
 
 from .base_config import BasicTSConfig
 from .model_config import BasicTSModelConfig
@@ -99,17 +100,26 @@ class BasicTSForecastingConfig(BasicTSConfig):
 
     # Dataset settings
     dataset_type: type = field(default=BasicTSForecastingDataset, metadata={"help": "Dataset type."})
-    dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."})
-    input_len: int = field(default=336, metadata={"help": "Input length."})
-    output_len: int = field(default=336, metadata={"help": "Output length."})
-    use_timestamps: bool = field(default=True, metadata={"help": "Whether to use timestamps as supplementary."})
-    memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."})
-    null_val: float = field(default=np.nan, metadata={"help": "Null value."})
-    null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."})
-
+    dataset_params: Union[dict, None] = field(
+        default_factory=lambda: {
+            "input_len": 336,
+            "output_len": 336,
+            "use_timestamps": True,
+            "memmap": False,
+        }, metadata={"help": "Dataset parameters."})
+
+    # shortcuts
+    input_len: int = field(default=None, metadata={"help": "Input length."})
+    output_len: int = field(default=None, metadata={"help": "Output length."})
+    use_timestamps: bool = field(default=None, metadata={"help": "Whether to use timestamps as supplementary."})
+    memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."})
+    batch_size: Union[int, None] = field(
+        default=None, metadata={"help": "Batch size. If set, all dataloaders will use the same batch size."})
+
+    null_val: float = field(default=np.nan, metadata={"help": "Null value."})
+    null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."})
 
     # Scaler settings
     scaler: type = field(default=ZScoreScaler, metadata={"help": "Scaler type."})
     norm_each_channel: bool = field(default=True, metadata={"help": "Whether to normalize data for each channel independently."})
@@ -147,7 +157,7 @@ class BasicTSForecastingConfig(BasicTSConfig):
     optimizer_params: dict = field(
         default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4},
         metadata={"help": "Optimizer parameters."})
-    lr: float = field(default=2e-4, metadata={"help": "Learning rate."})
+    lr: float = field(default=None, metadata={"help": "Learning rate."})
 
     # Learning rate scheduler
     lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."})
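The same restructuring repeats across these config classes (classification above, foundation-model and imputation below): concrete defaults move into the `dataset_params`/`optimizer_params` factories, and the flat fields become `None`-defaulted shortcuts that override the dict only when explicitly set, as enforced by the `_pack_params` guard earlier in this PR. A hedged sketch of the resolution order, using an illustrative `resolve` helper rather than the real BasicTS code path:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ForecastCfg:
    dataset_params: dict = field(
        default_factory=lambda: {"input_len": 336, "output_len": 336})
    input_len: Optional[int] = None   # shortcut; overrides the dict when set
    output_len: Optional[int] = None


def resolve(cfg: ForecastCfg) -> dict:
    params = dict(cfg.dataset_params)
    for k in params:
        shortcut = getattr(cfg, k, None)
        if shortcut is not None:
            params[k] = shortcut
    return params


print(resolve(ForecastCfg()))              # {'input_len': 336, 'output_len': 336}
print(resolve(ForecastCfg(input_len=96)))  # {'input_len': 96, 'output_len': 336}
```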
13 changes: 7 additions & 6 deletions src/basicts/configs/tsfm_config.py
@@ -3,12 +3,13 @@
 
 import numpy as np
 import torch
+from torch.optim import AdamW
+
 from basicts.data import BasicTSForecastingDataset
 from basicts.runners.callback import BasicTSCallback
 from basicts.runners.optim.lr_schedulers import CosineWarmup
 from basicts.runners.taskflow import (BasicTSForecastingTaskFlow,
                                       BasicTSTaskFlow)
-from torch.optim import AdamW
 
 from .base_config import BasicTSConfig
 from .model_config import BasicTSModelConfig
@@ -100,10 +101,10 @@ class BasicTSFoundationModelConfig(BasicTSConfig):
     # Dataset settings
     dataset_type: type = field(default=BasicTSForecastingDataset, metadata={"help": "Dataset type."})
     dataset_params: dict = field(default_factory=dict)
-    input_len: int = field(default=336, metadata={"help": "Input length."})
-    output_len: int = field(default=336, metadata={"help": "Output length."})
-    use_timestamps: bool = field(default=False, metadata={"help": "Whether to use timestamps as supplementary."})
-    memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."})
+    input_len: int = field(default=None, metadata={"help": "Input length."})
+    output_len: int = field(default=None, metadata={"help": "Output length."})
+    use_timestamps: bool = field(default=None, metadata={"help": "Whether to use timestamps as supplementary."})
+    memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."})
     batch_size: Optional[int] = field(default=None, metadata={"help": "Batch size. If set, all dataloaders will use the same batch size."})
     null_val: float = field(default=np.nan, metadata={"help": "Null value."})
     null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."})
@@ -142,7 +143,7 @@ class BasicTSFoundationModelConfig(BasicTSConfig):
     # Optimizer
     optimizer: type = field(default=AdamW)
     optimizer_params: dict = field(default_factory=lambda: {"lr": 1e-3, "fused": True})
-    lr: float = field(default=1e-3, metadata={"help": "Learning rate."})
+    lr: float = field(default=None, metadata={"help": "Learning rate."})
 
     # Learning rate scheduler
     lr_scheduler: type = field(default=CosineWarmup)
26 changes: 17 additions & 9 deletions src/basicts/configs/tsi_config.py
@@ -2,11 +2,12 @@
 from typing import Callable, List, Literal, Tuple, Union
 
 import numpy as np
+from torch.optim import Adam
+
 from basicts.data import BasicTSImputationDataset
 from basicts.runners.callback import BasicTSCallback
 from basicts.runners.taskflow import BasicTSImputationTaskFlow, BasicTSTaskFlow
 from basicts.scaler import ZScoreScaler
-from torch.optim import Adam
 
 from .base_config import BasicTSConfig
 from .model_config import BasicTSModelConfig
@@ -98,17 +99,24 @@ class BasicTSImputationConfig(BasicTSConfig):
 
     # Dataset settings
     dataset_type: type = field(default=BasicTSImputationDataset, metadata={"help": "Dataset type."})
-    dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."})
-    input_len: int = field(default=336, metadata={"help": "Input length."})
+    dataset_params: Union[dict, None] = field(
+        default_factory=lambda: {
+            "input_len": 336,
+            "use_timestamps": True,
+            "memmap": False,
+        }, metadata={"help": "Dataset parameters."})
+
+    # shortcuts
+    input_len: int = field(default=None, metadata={"help": "Input length."})
+    use_timestamps: bool = field(default=None, metadata={"help": "Whether to use timestamps as supplementary."})
+    memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."})
+    batch_size: Union[int, None] = field(
+        default=None, metadata={"help": "Batch size. If set, all dataloaders will use the same batch size."})
+
     mask_ratio: float = field(default=0.25, metadata={"help": "Mask ratio."})
-    use_timestamps: bool = field(default=True, metadata={"help": "Whether to use timestamps as supplementary."})
-    memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."})
     null_val: float = field(default=np.nan, metadata={"help": "Null value."})
     null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."})
 
-    batch_size: Union[int, None] = field(
-        default=None, metadata={"help": "Batch size. If set, all dataloaders will use the same batch size."})
-
     # Scaler settings
     scaler: type = field(default=ZScoreScaler, metadata={"help": "Scaler type."})
     norm_each_channel: bool = field(default=True, metadata={"help": "Whether to normalize data for each channel independently."})
@@ -146,7 +154,7 @@ class BasicTSImputationConfig(BasicTSConfig):
     optimizer_params: dict = field(
         default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4},
         metadata={"help": "Optimizer parameters."})
-    lr: float = field(default=2e-4, metadata={"help": "Learning rate."})
+    lr: float = field(default=None, metadata={"help": "Learning rate."})
 
     # Learning rate scheduler
     lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."})
3 changes: 2 additions & 1 deletion src/basicts/data/base_dataset.py
@@ -1,9 +1,10 @@
 from typing import Union
 
 import numpy as np
-from basicts.utils.constants import BasicTSMode
 from torch.utils.data import Dataset
 
+from basicts.utils.constants import BasicTSMode
+
 
 class BasicTSDataset(Dataset):
     """
12 changes: 11 additions & 1 deletion src/basicts/data/blast.py
@@ -3,6 +3,7 @@
 from typing import Optional, Union
 
 import numpy as np
+
 from basicts.utils.constants import BasicTSMode
 
 from .base_dataset import BasicTSDataset
@@ -56,7 +57,7 @@ class BLAST(BasicTSDataset):
 
     def __post_init__(self):
         # load data
-        self.data = self._load_data()
+        self._data = self._load_data()
         self.output_len = self.output_len or 0
 
         # minimum valid history sequence length
@@ -244,6 +245,15 @@ def __getitem__(self, idx: int) -> tuple:
     def __len__(self):
         return self.data.shape[0]
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_data"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._data = self._load_data()
+
     @property
     def data(self) -> np.ndarray:
         return self._data
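The `__getstate__`/`__setstate__` pair makes the dataset safe to pickle, which matters once a PyTorch `DataLoader` with `num_workers > 0` (or DDP with the spawn start method) has to serialize the dataset into worker processes: the bulky `_data` array is dropped from the pickled state and lazily reloaded on the receiving side. A generic, self-contained sketch of the pattern (the file-free `_load_data` is a stand-in for something like `np.load(..., mmap_mode='r')`):

```python
import pickle


class LazyArrayDataset:
    """Pickles cleanly by excluding a heavy or non-picklable member."""

    def __init__(self, path: str):
        self.path = path
        self._data = self._load_data()

    def _load_data(self):
        # stand-in for reading self.path from disk
        return list(range(10))

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_data"]              # do not ship the data itself
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._data = self._load_data()  # reload in the receiving process


ds = LazyArrayDataset("data.npy")
clone = pickle.loads(pickle.dumps(ds))
assert clone._data == ds._data
```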
1 change: 1 addition & 0 deletions src/basicts/data/tsf_dataset.py
@@ -2,6 +2,7 @@
 from typing import Union
 
 import numpy as np
+
 from basicts.utils.constants import BasicTSMode
 
 from .base_dataset import BasicTSDataset
1 change: 1 addition & 0 deletions src/basicts/data/tsi_dataset.py
@@ -2,6 +2,7 @@
 from typing import Union
 
 import numpy as np
+
 from basicts.utils.constants import BasicTSMode
 
 from .base_dataset import BasicTSDataset