
remigenet / tkat


Temporal Kolmogorov-Arnold Transformer

License: Other

Jupyter Notebook 87.61% Python 12.39%
temporal-networks tensorflow timeseries timeseries-forecasting tkan transformer tkat jax keras keras3 torch

tkat's Introduction

Temporal Kolmogorov-Arnold Transformer for Time Series Forecasting

[Figure: TKAT representation]

This repository contains the original code for the paper of the same name. The model is implemented in Keras 3 and supports all of its backends (JAX, TensorFlow, PyTorch).
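Keras 3 selects its backend from the `KERAS_BACKEND` environment variable (falling back to `~/.keras/keras.json`), and the variable has to be set before Keras is first imported. A minimal sketch, assuming JAX is the backend you want and that it is installed:

```python
import os

# Keras 3 reads this once, when `keras` is first imported, so set it before
# importing keras or tkat. Accepted values: "tensorflow", "jax", "torch".
os.environ["KERAS_BACKEND"] = "jax"

# import keras            # would now run on JAX
# from tkat import TKAT
```

The README's usage example imports TensorFlow for its callbacks and optimizer; with another backend, the equivalent `keras.callbacks` and `keras.optimizers` classes can be used instead.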

It is inspired by the Temporal Fusion Transformer from google-research and by the Temporal Kolmogorov-Arnold Network (TKAN).

The Temporal Kolmogorov-Arnold Transformer replaces the internal LSTM encoder and decoder of the Temporal Fusion Transformer with TKAN layers, with the aim of improving forecasting performance. It requires the tkan package, version >= 0.2.
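To check which tkan version is installed before importing TKAT, a small standard-library sketch could look like the following. `parse_version` and `tkan_is_compatible` are hypothetical helpers, not part of tkat; if the `packaging` library is available, `packaging.version.Version` is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version(text: str) -> tuple:
    """Return the leading numeric release part, e.g. '0.2.1' -> (0, 2, 1).

    Simplified: suffixes such as 'rc1' or '.dev0' are dropped, so a
    pre-release compares equal to its final release.
    """
    match = re.match(r"\d+(?:\.\d+)*", text)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def tkan_is_compatible(minimum: str = "0.2") -> bool:
    """True if an installed `tkan` distribution is at least `minimum`."""
    try:
        installed = version("tkan")
    except PackageNotFoundError:
        return False
    return parse_version(installed) >= parse_version(minimum)


print(tkan_is_compatible())  # False if tkan is missing or older than 0.2
```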

However, TKAT differs from the Temporal Fusion Transformer in several respects, such as the absence of static inputs and a different architecture after the multi-head attention layer.

Installation

A PyPI package is available for the TKAT implementation. You can install it directly from PyPI:

pip install tkat

Alternatively, clone the repository and install it with:

pip install path/to/tkat

Usage

Unlike the TKAN package, which provides layers, TKAT is a full model implementation and can therefore be used directly as a model. Here is an example of how to use it:

import tensorflow as tf
from tkat import TKAT

N_MAX_EPOCHS = 100
BATCH_SIZE = 128

# Stop when val_loss stops improving (monitored from epoch 6 on) and restore the best weights
early_stopping_callback = lambda: tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00001,
    patience=6,
    mode="min",
    restore_best_weights=True,
    start_from_epoch=6,
)
# Divide the learning rate by 4 after 3 epochs without val_loss improvement
lr_callback = lambda: tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.25,
    patience=3,
    mode="min",
    min_delta=0.00001,
    min_lr=0.000025,
    verbose=0,
)
# Each call builds fresh callback instances
callbacks = lambda: [early_stopping_callback(), lr_callback(), tf.keras.callbacks.TerminateOnNaN()]


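sequence_length = 30      # number of past time steps
num_unknow_features = 8   # features observed only in the past
num_know_features = 2     # features known in advance over the forecast horizon
num_embedding = 1
num_hidden = 100
num_heads = 4
n_ahead = 1               # forecast horizon (number of steps to predict); example value
use_tkan = True

model = TKAT(sequence_length, num_unknow_features, num_know_features, num_embedding, num_hidden, num_heads, n_ahead, use_tkan=use_tkan)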
optimizer = tf.keras.optimizers.Adam(0.001)
model.compile(optimizer=optimizer, loss='mean_squared_error')

model.summary()

# X_train, y_train and X_test are NumPy arrays shaped as described below
history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=N_MAX_EPOCHS, validation_split=0.2, callbacks=callbacks(), shuffle=True, verbose=0)

test_preds = model.predict(X_test)

X_train should be a NumPy array of shape (n_samples, sequence_length + n_ahead, num_unknow_features + num_know_features), and y_train should be a NumPy array of shape (n_samples, n_ahead). The values in X_train[:, sequence_length:, :num_unknow_features], i.e. the unknown features over the forecast horizon, are not used and can be set to 0. The known inputs must be the last num_know_features features along the last axis of X_train.
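The layout above can be produced from a single multivariate series with sliding windows. This is a minimal NumPy sketch, assuming the series is already split into an `unknown` block and a `known` block, and assuming the target is one of the unknown features (selected by `target_col`); `make_tkat_windows` and `target_col` are hypothetical names, not part of the tkat package.

```python
import numpy as np


def make_tkat_windows(unknown, known, sequence_length, n_ahead, target_col=0):
    """Build (X, y) windows in the layout TKAT expects.

    unknown: (n_timesteps, num_unknow_features), observed only in the past
    known:   (n_timesteps, num_know_features), known in advance
    target_col: index within `unknown` of the series to forecast (assumption)
    """
    unknown = np.asarray(unknown, dtype=np.float32)
    known = np.asarray(known, dtype=np.float32)
    # Unknown features first, known features last.
    features = np.concatenate([unknown, known], axis=1)
    window = sequence_length + n_ahead
    n_samples = len(features) - window + 1
    # X: (n_samples, sequence_length + n_ahead, n_unknown + n_known)
    X = np.stack([features[i:i + window] for i in range(n_samples)])
    # y: the target over the n_ahead future steps of each window
    y = np.stack([unknown[i + sequence_length:i + window, target_col]
                  for i in range(n_samples)])
    # Future values of the unknown features are ignored by the model; set them to 0.
    X[:, sequence_length:, :unknown.shape[1]] = 0.0
    return X, y


rng = np.random.default_rng(0)
unknown = rng.normal(size=(100, 8))
known = rng.normal(size=(100, 2))
X, y = make_tkat_windows(unknown, known, sequence_length=30, n_ahead=1)
print(X.shape, y.shape)  # (70, 31, 10) (70, 1)
```

Python-level loops keep the sketch readable; for long series, `numpy.lib.stride_tricks.sliding_window_view` avoids them.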

For a more detailed example, see the notebook in the example folder.

Please cite our work if you use this repo:

@article{genet2024temporal,
  title={A Temporal Kolmogorov-Arnold Transformer for Time Series Forecasting},
  author={Genet, Remi and Inzirillo, Hugo},
  journal={arXiv preprint arXiv:2406.02486},
  year={2024}
}


This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.


tkat's People

Contributors

remigenet


tkat's Issues

Loss stops changing with custom class

I made a custom class for classification instead of forecasting, taking out the decoder. I also added static features (including categorical).

Whenever I run this model, at some epoch (depending on the optimizer, etc.) the validation loss stops changing, and from the next epoch on neither the validation nor the training loss changes. Below is an extreme example where this happens immediately (sometimes it takes 2-3 epochs):

Epoch 1/100
135/135 - 95s - 703ms/step - f1_score: 0.0275 - loss: 12.8403 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 0.0010
Epoch 2/100
135/135 - 27s - 200ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 0.0010
Epoch 3/100
135/135 - 26s - 195ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 0.0010
Epoch 4/100
135/135 - 26s - 193ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 0.0010
Epoch 5/100
135/135 - 26s - 192ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 2.5000e-04
Epoch 6/100
135/135 - 26s - 189ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 2.5000e-04
Epoch 7/100
135/135 - 26s - 193ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 2.5000e-04
Epoch 8/100
135/135 - 25s - 188ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 6.2500e-05
Epoch 9/100
135/135 - 26s - 192ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 6.2500e-05
Epoch 10/100
135/135 - 27s - 196ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 6.2500e-05
Epoch 11/100
135/135 - 26s - 191ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 2.5000e-05
Epoch 12/100
135/135 - 25s - 189ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 2.5000e-05
Epoch 13/100
135/135 - 26s - 196ms/step - f1_score: 0.0214 - loss: 13.8861 - val_f1_score: 0.0201 - val_loss: 14.1408 - learning_rate: 2.5000e-05
10/11 ━━━━━━━━━━━━━━━━━━━━ 0s 94ms/step

This is my classifier:

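import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Model
from keras.layers import Concatenate, Conv1D, Embedding, Flatten, Input, MultiHeadAttention

# Not shown in this snippet: EmbeddingLayer, VariableSelectionNetwork, RecurrentLayer,
# AddAndNorm, Gate, GRN, KANLinear (TKAT/TKAN building blocks) and the build_width helper.

def TKAT_classify(X: pd.DataFrame, num_embedding: int, num_hidden: int, num_heads: int, n_classes: int,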
                  use_tkan: bool = True, filters=32, strides=16):
    X = X.copy()

    cat_cols = X.columns[X.dtypes == 'category']

    # X has dynamic (time series), numerical (static), and categorical (static) features.
    # Get the embedding size for the categorical features
    num_dynamic_features = X[X.columns[X.dtypes == object]].shape[1]
    num_static_features = X[X.columns[(X.dtypes != object) & (X.dtypes != "category")]].shape[1]
    num_categorical_features = X[cat_cols].shape[1]

    # assert num_dynamic_features == 18 and num_static_features == 5 and num_categorical_features == 3

    cat_embed_dict = {col: build_width(X[col].nunique())[-1] for col in cat_cols}

    # assign a unique integer to each category
    cat2int = {col: {cat: i for i, cat in enumerate(X[col].unique(), 1)}
               for col in X[cat_cols]}
    for col in cat2int:
        cat2int[col][np.nan] = 0

    # create embedding layers for each categorical feature
    categorical_embedding = {
        col: Embedding(X[col].nunique() + 1, cat_embed_dict[col], name=f'embedding_{col}')
        for col in cat_cols}

    dynamic_inputs = Input(shape=(len(X.iloc[0, 0]), num_dynamic_features))
    static_inputs = Input(shape=(num_static_features,))
    categorical_inputs = [Input(shape=(1,), dtype=tf.int32) for _ in range(num_categorical_features)]

    # First, convolutional layer to reduce the number of time steps
    conv_inputs = Conv1D(filters, 3 * strides, strides=strides, padding="same", activation='silu')(dynamic_inputs)

    dynamic_embedding = EmbeddingLayer(num_embedding)(conv_inputs)

    variable_selection = VariableSelectionNetwork(num_hidden, name='vsn_past_features')(dynamic_embedding)

    # recurrent encoder
    encode_out, *encode_states = RecurrentLayer(num_hidden, return_state=True, use_tkan=use_tkan, name='encoder')(
        variable_selection)

    # feed forward
    all_context = AddAndNorm()([Gate()(encode_out), variable_selection])

    # GRN using TKAN before attention
    enriched = GRN(num_hidden)(all_context)

    # attention
    attention_output = MultiHeadAttention(num_heads=num_heads, key_dim=enriched.shape[-1]
                                          )(enriched, enriched, enriched)
    attention_flattened = KANLinear(num_hidden)(Flatten()(attention_output))

    # Flatten the attention output and predict the future sequence
    # concatenate the flattened attention output with the static output and the categorical embeddings
    flattened_output = Concatenate()([attention_flattened, static_inputs] +
                                     [Flatten()(categorical_embedding[col](categorical_inputs[i]))
                                      for i, col in enumerate(cat_cols)])
    dense_output = KANLinear(n_classes)(flattened_output)

    return Model(inputs=[dynamic_inputs, static_inputs, *categorical_inputs], outputs=dense_output), cat2int

I don't know whether this issue is specific to my data or reflects a problem with how the grid is updated in the KAN layers.
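Assuming the reporter kept the callback settings from the README's usage example (EarlyStopping with patience=6 and start_from_epoch=6; ReduceLROnPlateau with factor=0.25, patience=3 and min_lr=0.000025; min_delta=0.00001 for both), the learning_rate column in the log and the stop after epoch 13 are exactly what those callbacks produce when val_loss stays at 14.1408 from epoch 1 on. The sketch below replays a simplified version of that plateau logic (a re-implementation for illustration, not the Keras source; `simulate_callbacks` is a hypothetical helper) and reproduces both:

```python
def simulate_callbacks(val_losses, initial_lr=0.001, factor=0.25, lr_patience=3,
                       min_lr=0.000025, es_patience=6, start_from_epoch=6,
                       min_delta=0.00001):
    """Replay simplified ReduceLROnPlateau + EarlyStopping logic (mode="min").

    Returns (learning rate in effect at each epoch, number of epochs run).
    """
    lr, lr_best, lr_wait = initial_lr, float("inf"), 0
    es_best, es_wait = float("inf"), 0
    logged = []
    for epoch, loss in enumerate(val_losses):  # 0-based epoch index
        logged.append(lr)  # the rate logged for this epoch
        # ReduceLROnPlateau: cut the rate after lr_patience epochs without improvement.
        if loss < lr_best - min_delta:
            lr_best, lr_wait = loss, 0
        else:
            lr_wait += 1
            if lr_wait >= lr_patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                lr_wait = 0
        # EarlyStopping: ignore epochs before start_from_epoch, then stop after
        # es_patience epochs without improvement.
        if epoch < start_from_epoch:
            continue
        es_wait += 1
        if loss < es_best - min_delta:
            es_best, es_wait = loss, 0
        elif es_wait >= es_patience:
            return logged, epoch + 1
    return logged, len(val_losses)


# Flat validation loss as in the log, with at most 100 epochs.
lrs, epochs_run = simulate_callbacks([14.1408] * 100)
print(epochs_run)  # 13
```

Under these assumptions, the learning-rate drops and the early stop are consequences of the stalled loss rather than its cause; they do not by themselves explain why the loss stalls.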

Missing dataset

Hi!
I find your work super interesting. I read your paper and would still like to replicate it to understand how your implementation works.
I tried to run example.ipynb, but the dataset is missing. Could it be provided?
I was able to build my own dataset by downloading the hourly data for your list of assets from Binance, but I am still wondering whether it is possible to get access to your original dataset.

Additionally, I was able to run your example with my dataset, but I wonder whether it can be trained on my Mac M2, or whether I need a GPU to use and train it.
(CORRECTION: Yes, it was possible to run it on my M2. It took 17 min.)

Thank you so much for any comments.
And again, CONGRATULATIONS on your super amazing work!

Best
