
cyria7 / tfb


This project forked from decisionintelligence/tfb


About Code release for "TFB: Towards Comprehensive and Fair Benchmarking of Time Series Forecasting Methods" (PVLDB 2024)



Time Series Forecasting Benchmark (TFB)

TFB is an open-source library designed for time series researchers.

We provide a clean codebase for end-to-end evaluation of time series forecasting models, comparing their performance with baseline algorithms under various evaluation strategies and metrics.
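To make the idea of an evaluation strategy concrete, here is a generic rolling-origin loop: the forecast origin slides forward, the model only sees the history before each origin, and the errors (MAE and MSE here) are averaged over all windows. This is an illustrative sketch, not TFB's implementation; the names `rolling_forecast_eval` and `last_value` and the `stride` parameter are hypothetical.

```python
import numpy as np


def rolling_forecast_eval(series, forecaster, train_len, pred_len, stride):
    """Generic rolling-origin evaluation (illustrative, not TFB's code).

    series:     1-D sequence of observations.
    forecaster: callable(history, pred_len) -> array of length pred_len.
    Returns (mean MAE, mean MSE) over all forecast windows.
    """
    series = np.asarray(series, dtype=float)
    maes, mses = [], []
    # Slide the forecast origin forward by `stride` while a full horizon remains.
    for origin in range(train_len, len(series) - pred_len + 1, stride):
        history = series[:origin]
        target = series[origin:origin + pred_len]
        pred = np.asarray(forecaster(history, pred_len), dtype=float)
        err = pred - target
        maes.append(np.mean(np.abs(err)))
        mses.append(np.mean(err ** 2))
    if not maes:
        raise ValueError("series too short for the requested train_len/pred_len")
    return float(np.mean(maes)), float(np.mean(mses))


def last_value(history, pred_len):
    # Naive baseline: repeat the last observed value over the horizon.
    return np.repeat(history[-1], pred_len)


mae, mse = rolling_forecast_eval(np.arange(10), last_value, train_len=4, pred_len=2, stride=2)
```

A larger `stride` evaluates fewer, more widely spaced windows, which trades statistical coverage for run time.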

Quickstart

Installation

Given a Python environment (note: this project has been fully tested under Python 3.8), install the dependencies with the following command:

pip install -r requirements.txt

Data preparation

Prepare data. You can obtain the pre-processed datasets from Google Drive. Then place the downloaded data under the folder ./dataset.
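After downloading, a quick sanity check can confirm that the files landed where the scripts expect them. This is a minimal sketch, assuming the CSV files sit somewhere under ./dataset (it searches subfolders because the exact layout inside the download is not specified here); `check_dataset` is a hypothetical helper, and `ETTh1.csv` is simply the dataset used in the README's examples.

```python
from pathlib import Path


def check_dataset(root="./dataset", required=("ETTh1.csv",)):
    """Return the sorted names of required CSV files missing under `root`.

    The default file name is an assumption taken from this README's examples;
    pass the datasets you actually plan to use.
    """
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"dataset folder not found: {root.resolve()}")
    # Search recursively, since the data may be nested in subfolders.
    present = {p.name for p in root.rglob("*.csv")}
    return sorted(name for name in required if name not in present)
```

Calling `check_dataset()` from the repository root returns an empty list when every required file is present.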

Train and evaluate a model

We provide the experiment scripts for all benchmarks under the folders ./scripts/multivariate_forecast and ./scripts/univariate_forecast. For example, you can reproduce an experiment result as follows:

sh ./scripts/multivariate_forecast/AQShunyi_script/Triformer.sh
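To see which dataset/model combinations have a ready-made script, you could index the script folder. This is a minimal sketch, assuming every script follows the `<Dataset>_script/<Model>.sh` layout of the example path `./scripts/multivariate_forecast/AQShunyi_script/Triformer.sh`; `list_scripts` and `run_script` are hypothetical helpers, not part of TFB.

```python
import subprocess
from pathlib import Path


def list_scripts(root="./scripts/multivariate_forecast"):
    """Return {dataset: [model, ...]} for scripts named <Dataset>_script/<Model>.sh."""
    index = {}
    for script in sorted(Path(root).glob("*_script/*.sh")):
        # Strip the "_script" suffix from the folder name to get the dataset.
        dataset = script.parent.name[: -len("_script")]
        index.setdefault(dataset, []).append(script.stem)
    return index


def run_script(root, dataset, model):
    """Run one experiment script with sh, raising CalledProcessError on failure."""
    path = Path(root) / f"{dataset}_script" / f"{model}.sh"
    subprocess.run(["sh", str(path)], check=True)
```

For example, `run_script("./scripts/multivariate_forecast", "AQShunyi", "Triformer")` would execute the same script as the command above.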

Steps to develop your own method

  1. Define your model or adapter class
  • The user-implemented model or adapter class should implement the following functions to work with this benchmark.

  • The required_hyper_params function is optional; forecast_fit, forecast and __repr__ are required.

  • The function prototype is as follows:

    • required_hyper_params function:

      """
      Return the hyperparameters required by the model
      This function is optional and static
      
      :return: A dictionary that represents the hyperparameters required by the model
      :rtype: dict
      """
      
      # For example
      @staticmethod
      def required_hyper_params() -> dict:
          """
          An empty dictionary indicating that the model does not require additional hyperparameters.
          """
          return {}
    • forecast_fit function: trains the model

      """
      Train the model.
      
      :param train_valid_data: Time series data used for training.
      :param train_val_ratio: The ratio used to split the data into training and validation sets. If it equals 1, no validation set is split off.
      """
      # For example
      def forecast_fit(self, train_valid_data: pd.DataFrame, train_val_ratio: float):
          pass
    • forecast function: uses the model for inference

      """
      Use models for forecasting
      
      :param pred_len: Prediction length
      :type pred_len: int
      :param train: Training data used to fit the model
      :type train: pd.DataFrame
      
      :return: Forecasting results
      :rtype: np.ndarray
      """
      # For example
      def forecast(self, pred_len: int, train: pd.DataFrame) -> np.ndarray:
          pass
    • __repr__ function: returns the model name as a string

      """
      Returns a string representation of the model name
      
      :return: Returns a string representation of the model name
      :rtype: str
      """
      # For example
      def __repr__(self) -> str:
          return self.model_name
  2. Configure your configuration file
  • Modify the corresponding config under the folder ./ts_benchmark/config/.

  • Modify the contents of ./scripts/run_benchmark.py.

  • We strongly recommend using the pre-defined configurations in ./ts_benchmark/config/. Create your own configuration file only when you have a clear understanding of the configuration items.

  3. Run the benchmark in the following format:
python ./scripts/run_benchmark.py --config-path "rolling_forecast_config.json" --data-name-list "ETTh1.csv" --strategy-args '{"pred_len":96}' --model-name "time_series_library.Triformer" --model-hyper-params '{"d_ff": 64, "d_model": 32, "pred_len": 96, "seq_len": 96}' --adapter "transformer_adapter"  --gpus 0  --num-workers 1  --timeout 60000  --save-path "ETTh1/Triformer"
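Putting step 1 together, here is a minimal model that implements all four functions. `NaiveLastValue` is a hypothetical example (it simply repeats the last observed row), not a baseline shipped with TFB; the `(pred_len, n_columns)` output shape is an assumption based on the VAR example in this README, whose forecast returns one row per future step and one column per series.

```python
import numpy as np
import pandas as pd


class NaiveLastValue:
    """Hypothetical minimal model following the benchmark interface.

    It forecasts by repeating the last observed row of every column.
    """

    def __init__(self):
        self.model_name = "NaiveLastValue"

    @staticmethod
    def required_hyper_params() -> dict:
        # No additional hyperparameters are required.
        return {}

    def forecast_fit(self, train_valid_data: pd.DataFrame, train_val_ratio: float):
        # Nothing to learn; a real model would train here and could use
        # train_val_ratio to hold out a validation split.
        pass

    def forecast(self, pred_len: int, train: pd.DataFrame) -> np.ndarray:
        last_row = train.to_numpy()[-1]
        # Shape (pred_len, n_columns): one row per future time step.
        return np.tile(last_row, (pred_len, 1))

    def __repr__(self) -> str:
        return self.model_name
```

How `--model-name` resolves a new class is not covered here; the VAR example suggests placing it under ./ts_benchmark/baselines/self_implementation/, but treat that location as an assumption.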

Example Usage

  • Define the model class or factory
    • We demonstrate which functions need to be implemented for time series forecasting, using the VAR algorithm as an example. You can find the complete code in ./ts_benchmark/baselines/self_implementation/VAR/VAR.py.
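import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from statsmodels.tsa.api import VAR


class VAR_model:
    """
    VAR class.

    This class wraps a VAR model for time series forecasting.
    """

    def __init__(self):
        self.scaler = StandardScaler()
        self.model_args = {}
        # Name returned by __repr__; without it, repr() raises AttributeError.
        self.model_name = "VAR"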

    @staticmethod
    def required_hyper_params() -> dict:
        """
        Return the hyperparameters required by VAR.

        :return: An empty dictionary indicating that VAR does not require additional hyperparameters.
        """
        
        return {}

    def forecast_fit(self, train_data: pd.DataFrame, train_val_ratio: float):
        """
        Train the model.

        :param train_data: Time series data used for training.
        :param train_val_ratio: The ratio used to split the data into training and validation sets. If it equals 1, no validation set is split off. VAR does not use it.
        """

        # Standardize each column using statistics from the training data.
        self.scaler.fit(train_data.values)
        train_data_value = pd.DataFrame(
            self.scaler.transform(train_data.values),
            columns=train_data.columns,
            index=train_data.index,
        )
        # Fit a VAR model with 13 lags on the standardized data.
        model = VAR(train_data_value)
        self.results = model.fit(13)

    def forecast(self, pred_len: int, testdata: pd.DataFrame) -> np.ndarray:
        """
        Make predictions.

        :param pred_len: The predicted length.
        :param testdata: Time series data used for prediction.
        :return: An array of predicted results.
        """
        
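        # Scale the input history with the statistics learned in forecast_fit.
        train = pd.DataFrame(
            self.scaler.transform(testdata.values),
            columns=testdata.columns,
            index=testdata.index,
        )
        # Forecast pred_len steps ahead, starting from the most recent observations.
        z = self.results.forecast(train.values, steps=pred_len)

        # Map the forecasts back to the original scale: shape (pred_len, n_columns).
        predict = self.scaler.inverse_transform(z)
        return predict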

    def __repr__(self) -> str:
        """
        Returns a string representation of the model name.
        """
        
        return self.model_name
  • Run benchmark using VAR

    python ./scripts/run_benchmark.py --config-path "rolling_forecast_config.json" --data-name-list "ETTh1.csv" --strategy-args '{"pred_len":96}' --model-name "self_implementation.VAR_model" --gpus 0  --num-workers 1  --timeout 60000  --save-path "ETTh1/VAR_model"
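The JSON-valued flags (`--strategy-args`, `--model-hyper-params`) are easy to mis-quote by hand. A minimal sketch that assembles the argument list with `json.dumps` and prints a shell-safe command with `shlex.join` (Python 3.8+); `build_command` is a hypothetical helper, and it only emits flags that appear in the commands above, with the same fixed `--gpus`, `--num-workers` and `--timeout` values.

```python
import json
import shlex


def build_command(model_name, data_name, pred_len, save_path,
                  config_path="rolling_forecast_config.json",
                  hyper_params=None, adapter=None):
    """Assemble a run_benchmark.py argument list like the examples in this README."""
    argv = [
        "python", "./scripts/run_benchmark.py",
        "--config-path", config_path,
        "--data-name-list", data_name,
        # JSON-encode structured arguments so they are always valid JSON.
        "--strategy-args", json.dumps({"pred_len": pred_len}),
        "--model-name", model_name,
    ]
    if hyper_params:
        argv += ["--model-hyper-params", json.dumps(hyper_params)]
    if adapter:
        argv += ["--adapter", adapter]
    argv += ["--gpus", "0", "--num-workers", "1",
             "--timeout", "60000", "--save-path", save_path]
    return argv


cmd = build_command("self_implementation.VAR_model", "ETTh1.csv", 96, "ETTh1/VAR_model")
print(shlex.join(cmd))  # quoted so it can be pasted into a shell
```

Passing the list directly to `subprocess.run(cmd, check=True)` avoids shell quoting altogether.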

Citation

If you find this repo useful, please cite our paper.


Contact

If you have any questions or suggestions, feel free to contact:

Or describe it in an issue.

