2 changes: 1 addition & 1 deletion .pylintrc
@@ -12,7 +12,7 @@ ignore=baselines,assets,checkpoints

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
-ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.CFF|^LICENSE
+ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE

# Pickle collected data for later comparisons.
persistent=no
6 changes: 3 additions & 3 deletions basicts/__init__.py
@@ -1,6 +1,6 @@
-from .launcher import launch_evaluation, launch_training
+from .launcher import launch_evaluation, launch_inference, launch_training
from .runners import BaseEpochRunner

-__version__ = '0.4.6.3'
+__version__ = '0.4.6.4'

-__all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseEpochRunner']
+__all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseEpochRunner', 'launch_inference']
214 changes: 214 additions & 0 deletions basicts/data/simple_inference_dataset.py
@@ -0,0 +1,214 @@
import json
import logging
from typing import List, Tuple, Union

import numpy as np
import pandas as pd

from .base_dataset import BaseDataset


class TimeSeriesInferenceDataset(BaseDataset):
"""
    A dataset class for time series inference tasks, where the input is a sequence of historical data points.

Attributes:
description_file_path (str): Path to the JSON file containing the description of the dataset.
description (dict): Metadata about the dataset, such as shape and other properties.
data (np.ndarray): The loaded time series data array.
        _raw_data (str or list): The raw data path or the raw data itself.
last_datetime (pd.Timestamp): The last datetime in the dataset. Used to generate time features of future data.
"""

def __init__(self, dataset_name:str, dataset: Union[str, list], input_len: int, output_len: int,
logger: logging.Logger = None, train_val_test_ratio: List[float] = None) -> None:
"""
Initializes the TimeSeriesInferenceDataset by setting up paths, loading data, and
preparing it according to the specified configurations.

Args:
dataset_name (str): The name of the dataset.
            dataset (str or list): The path to the dataset file, or the data itself.
            input_len (int): The length of the input sequence (number of historical points).
            output_len (int): The length of the output sequence (number of future points to predict).
            logger (logging.Logger): Logger instance.
            train_val_test_ratio (List[float]): The ratio of train, validation, and test data. Kept only for compatibility.
        """
        # Inference always uses fixed settings; the split ratio passed in is ignored.
        train_val_test_ratio: List[float] = []
        mode: str = 'inference'
        overlap = False
super().__init__(dataset_name, train_val_test_ratio, mode, input_len, output_len, overlap)
self.logger = logger

self.description_file_path = f'datasets/{dataset_name}/desc.json'
self.description = self._load_description()

self.last_datetime:pd.Timestamp = pd.Timestamp.now()
self._raw_data = dataset
self.data = self._load_data()

def _load_description(self) -> dict:
"""
Loads the description of the dataset from a JSON file.

Returns:
dict: A dictionary containing metadata about the dataset, such as its shape and other properties.

Raises:
FileNotFoundError: If the description file is not found.
            ValueError: If there is an error decoding the JSON data.
"""
try:
with open(self.description_file_path, 'r') as f:
return json.load(f)
except FileNotFoundError as e:
raise FileNotFoundError(f'Description file not found: {self.description_file_path}') from e
except json.JSONDecodeError as e:
raise ValueError(f'Error decoding JSON file: {self.description_file_path}') from e

def _load_data(self) -> np.ndarray:
"""
Loads the time series data from a file or list and processes it according to the dataset description.

        Returns:
            np.ndarray: The processed data array with temporal features appended.

Raises:
ValueError: If there is an issue with loading the data file or if the data shape is not as expected.
"""

if isinstance(self._raw_data, str):
df = pd.read_csv(self._raw_data, header=None)
else:
df = pd.DataFrame(self._raw_data)

df_index = pd.to_datetime(df[0].values, format='%Y-%m-%d %H:%M:%S').to_numpy()
df = df[df.columns[1:]]
df.index = pd.Index(df_index)
df = df.astype('float32')
self.last_datetime = df.index[-1]

data = np.expand_dims(df.values, axis=-1)
data = data[..., [0]]

data_with_features = self._add_temporal_features(data, df)

data_set_shape = self.description['shape']
_, n, c = data_with_features.shape
if data_set_shape[1] != n or data_set_shape[2] != c:
raise ValueError(f'Error loading data. Shape mismatch: expected {data_set_shape[1:]}, got {[n,c]}.')

return data_with_features
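    # Illustrative note (not part of the diff): _load_data expects a headerless CSV, or an equivalent
    # list of rows, whose first column is a '%Y-%m-%d %H:%M:%S' timestamp and whose remaining columns
    # are the per-node values, e.g. (hypothetical data):
    #     2024-01-03 12:00:00,11.0,23.5,8.2
    #     2024-01-03 12:05:00,12.1,22.9,8.4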

def _add_temporal_features(self, data, df) -> np.ndarray:
'''
Add time of day and day of week as features to the data.

Args:
data (np.ndarray): The data array.
df (pd.DataFrame): The dataframe containing the datetime index.

Returns:
np.ndarray: The data array with added time of day and day of week features.
'''

_, n, _ = data.shape
feature_list = [data]

# numerical time_of_day
tod = (df.index.hour*60 + df.index.minute) / (24*60)
tod_tiled = np.tile(tod, [1, n, 1]).transpose((2, 1, 0))
feature_list.append(tod_tiled)

# numerical day_of_week
dow = df.index.dayofweek / 7
dow_tiled = np.tile(dow, [1, n, 1]).transpose((2, 1, 0))
feature_list.append(dow_tiled)

# numerical day_of_month
        dom = (df.index.day - 1) / 31 # df.index.day starts from 1, so subtract 1 to make it zero-based.
dom_tiled = np.tile(dom, [1, n, 1]).transpose((2, 1, 0))
feature_list.append(dom_tiled)

# numerical day_of_year
        doy = (df.index.dayofyear - 1) / 366 # df.index.dayofyear starts from 1, so subtract 1 to make it zero-based.
doy_tiled = np.tile(doy, [1, n, 1]).transpose((2, 1, 0))
feature_list.append(doy_tiled)

data_with_features = np.concatenate(feature_list, axis=-1).astype('float32') # L x N x C

# Remove extra features
data_set_shape = self.description['shape']
        data_with_features = data_with_features[..., :data_set_shape[2]]

return data_with_features
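    # Illustrative note (not part of the diff): for a timestamp of 2024-01-03 12:30 (a Wednesday),
    # the normalized features computed above evaluate to
    #     time_of_day  = (12*60 + 30) / (24*60) ≈ 0.5208
    #     day_of_week  = 2 / 7                  ≈ 0.2857
    #     day_of_month = (3 - 1) / 31           ≈ 0.0645
    #     day_of_year  = (3 - 1) / 366          ≈ 0.0055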

def append_data(self, new_data: np.ndarray) -> None:
"""
        Append new data to the existing data and advance the last datetime accordingly.

Args:
new_data (np.ndarray): The new data to append to the existing data.
"""

freq = self.description['frequency (minutes)']
l, _, _ = new_data.shape

data_with_features, datetime_list = self._gen_datetime_list(new_data, self.last_datetime, freq, l)
self.last_datetime = datetime_list[-1]

self.data = np.concatenate([self.data, data_with_features], axis=0)

def _gen_datetime_list(self, new_data: np.ndarray, start_datetime: pd.Timestamp, freq: int, num_steps: int) -> Tuple[np.ndarray, List[pd.Timestamp]]:
"""
Generate a list of datetime objects based on the start datetime, frequency, and number of steps.

Args:
start_datetime (pd.Timestamp): The starting datetime for the sequence.
freq (int): The frequency of the data in minutes.
num_steps (int): The number of steps in the sequence.

Returns:
List[pd.Timestamp]: A list of datetime objects corresponding to the sequence.
"""
datetime_list = [start_datetime]
for _ in range(num_steps):
datetime_list.append(datetime_list[-1] + pd.Timedelta(minutes=freq))
new_index = pd.Index(datetime_list[1:])
new_df = pd.DataFrame()
new_df.index = new_index
data_with_features = self._add_temporal_features(new_data, new_df)

return data_with_features, datetime_list

def __getitem__(self, index: int) -> dict:
"""
Retrieves a sample from the dataset, considering both the input and output lengths.
For inference, the input data is the last 'input_len' points in the dataset, and the output data is the next 'output_len' points.

        Args:
            index (int): The index of the desired sample. Ignored, since inference always uses the last window.

        Returns:
            dict: A dictionary containing 'inputs' (the last 'input_len' points of the dataset) and 'target'
                (zero-valued placeholders carrying the temporal features of the next 'output_len' steps).
"""
history_data = self.data[-self.input_len:]

freq = self.description['frequency (minutes)']
_, n, _ = history_data.shape
future_data = np.zeros((self.output_len, n, 1))

data_with_features, _ = self._gen_datetime_list(future_data, self.last_datetime, freq, self.output_len)
return {'inputs': history_data, 'target': data_with_features}
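    # Illustrative note (not part of the diff): with input_len=12, output_len=12 and a dataset whose
    # description shape is (L, N, C), the returned sample would be
    #     {'inputs': array of shape (12, N, C),    # the last 12 observed steps
    #      'target': array of shape (12, N, C)}    # zeros plus the temporal features of the future steps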

def __len__(self) -> int:
"""
Calculates the total number of samples available in the dataset.
For inference, there is only one valid sample, as the input data is the last 'input_len' points in the dataset.

Returns:
int: The number of valid samples that can be drawn from the dataset, based on the configurations of input and output lengths.
"""
return 1
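
A minimal usage sketch of the new dataset class (not part of the diff). The dataset name, CSV path, and window lengths below are hypothetical, and the sketch assumes a matching datasets/MyDataset/desc.json exists, as required by _load_description:

from basicts.data.simple_inference_dataset import TimeSeriesInferenceDataset

# Build the inference dataset from a recent slice of observations (hypothetical paths).
dataset = TimeSeriesInferenceDataset(
    dataset_name='MyDataset',                  # looks up datasets/MyDataset/desc.json
    dataset='datasets/MyDataset/recent.csv',   # headerless CSV: timestamp column + one value column per node
    input_len=12,
    output_len=12,
)

sample = dataset[0]                            # the single inference sample (see __len__)
print(sample['inputs'].shape, sample['target'].shape)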
101 changes: 101 additions & 0 deletions basicts/launcher.py
@@ -132,3 +132,104 @@ def launch_training(cfg: Union[Dict, str],

# launch the training process
easytorch.launch_training(cfg=cfg, devices=gpus, node_rank=node_rank)

def inference_func(cfg: Dict,
input_data_file_path: str,
output_data_file_path: str,
ckpt_path: str,
strict: bool = True) -> None:
"""
Starts the inference process.

This function performs the following steps:
1. Initializes the runner specified in the configuration (`cfg`).
2. Sets up logging for the inference process.
3. Loads the model checkpoint.
4. Executes the inference pipeline using the initialized runner.

Args:
cfg (Dict): EasyTorch configuration dictionary.
input_data_file_path (str): Path to the input data file.
output_data_file_path (str): Path to the output data file.
        ckpt_path (str): Path to the model checkpoint. If not provided or not found, the best checkpoint in the checkpoint directory is loaded automatically.
strict (bool): Enforces that the checkpoint keys match the model. Defaults to True.

Raises:
Exception: Catches any exception, logs the traceback, and re-raises it.
"""

# initialize the runner
logger = get_logger('easytorch-launcher')
logger.info(f"Initializing runner '{cfg['RUNNER']}'")
runner = cfg['RUNNER'](cfg)

# initialize the logger for the runner
runner.init_logger(logger_name='easytorch-inference', log_file_name='inference_log')

# setup the graph if needed
if runner.need_setup_graph:
runner.setup_graph(cfg=cfg, train=False)

try:
# load the model checkpoint
if ckpt_path is None or not os.path.exists(ckpt_path):
ckpt_path_auto = os.path.join(runner.ckpt_save_dir, '{}_best_val_{}.pt'.format(runner.model_name, runner.target_metrics.replace('/', '_')))
logger.info(f'Checkpoint file not found at {ckpt_path}. Loading the best model checkpoint `{ckpt_path_auto}` automatically.')
if not os.path.exists(ckpt_path_auto):
raise FileNotFoundError(f'Checkpoint file not found at {ckpt_path}')
runner.load_model(ckpt_path=ckpt_path_auto, strict=strict)
else:
logger.info(f'Loading model checkpoint from {ckpt_path}')
runner.load_model(ckpt_path=ckpt_path, strict=strict)

# start the inference pipeline
runner.inference_pipeline(cfg=cfg, input_data=input_data_file_path, output_data_file_path=output_data_file_path)

except BaseException as e:
# log the exception and re-raise it
runner.logger.error(traceback.format_exc())
raise e

def launch_inference(cfg: Union[Dict, str],
ckpt_path: str,
input_data_file_path: str,
output_data_file_path: str,
device_type: str = 'gpu',
gpus: Optional[str] = None) -> None:
"""
Launches the inference process.

Args:
cfg (Union[Dict, str]): EasyTorch configuration as a dictionary or a path to a config file.
ckpt_path (str): Path to the model checkpoint.
input_data_file_path (str): Path to the input data file.
output_data_file_path (str): Path to the output data file.
device_type (str, optional): Device type to use ('cpu' or 'gpu'). Defaults to 'gpu'.
gpus (Optional[str]): GPU device IDs to use. Defaults to None (use all available GPUs).

    """

logger = get_logger('easytorch-launcher')
logger.info('Launching EasyTorch inference.')

# check params
    # a cfg path that starts with './' or '.\\' will crash easytorch, so strip the leading prefix
while isinstance(cfg, str) and cfg.startswith(('./','.\\')):
cfg = cfg[2:]
while ckpt_path.startswith(('./','.\\')):
ckpt_path = ckpt_path[2:]

# initialize the configuration
cfg_dict = init_cfg(cfg, save=True)

# set the device type (CPU, GPU, or MLU)
set_device_type(device_type)

# set the visible GPUs if the device type is not CPU
if device_type != 'cpu':
set_visible_devices(gpus)

# run the inference process
inference_func(cfg_dict, input_data_file_path, output_data_file_path, ckpt_path)
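
Together with the new export in basicts/__init__.py, the inference entry point can be driven end to end roughly as below. This is a sketch, not part of the diff; the config, checkpoint, and CSV paths are hypothetical.

from basicts import launch_inference

# Run inference with a trained model (hypothetical paths).
launch_inference(
    cfg='examples/MyModel/config.py',           # EasyTorch config dict or path to a config file
    ckpt_path='checkpoints/MyModel_best.pt',    # falls back to the best checkpoint if this file is missing
    input_data_file_path='datasets/MyDataset/recent.csv',
    output_data_file_path='outputs/prediction.csv',
    device_type='gpu',
    gpus='0',
)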