bighuang624 / DSANet

Code for the CIKM 2019 paper "DSANet: Dual Self-Attention Network for Multivariate Time Series Forecasting".

Home Page: https://kyonhuang.top/publication/dual-self-attention-network


dsanet's Introduction

Dual Self-Attention Network for Multivariate Time Series Forecasting

20.10.26 Update: Because frequent breaking changes in pytorch-lightning have made installation and code maintenance difficult, the code no longer runs correctly. We hope it remains useful as a reference, especially the part about the model, but we are sorry that we will no longer maintain the project. We recommend referring to other applications of the self-attention mechanism to time series, such as "Enhancing the Locality and Breaking the Memory Bottleneck of Transformer on Time Series Forecasting" and https://github.com/maxjcohen/transformer.

This project is the PyTorch implementation of the paper "DSANet: Dual Self-Attention Network for Multivariate Time Series Forecasting", in which we propose a dual self-attention network (DSANet) for multivariate time series forecasting. The network architecture is illustrated in the figure in the paper, which also gives more details about the effect of each component.

Requirements

  • Python 3.5 or above
  • PyTorch 1.1 or above
  • pytorch-lightning (see the version note below)
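Note that requirements.txt only lower-bounds these versions, while the code depends on pytorch-lightning APIs (e.g. the models module) that no longer exist as of release 0.4.8, as an issue below points out. A hedged install sketch, assuming the era-appropriate versions reported in the issues below rather than an officially tested combination:

# version pins are assumptions drawn from the issue reports below,
# not an officially tested combination
pip install torch==1.2.0
pip install pytorch-lightning==0.4.6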

How to run

You need to prepare the dataset first; see the dataset preparation instructions linked from the original repository.

# clone project
git clone https://github.com/bighuang624/DSANet.git

# install dependencies
cd DSANet
pip install -r requirements.txt

# run
python single_cpu_trainer.py --data_name {data_name} --n_multiv {n_multiv}

Notice: At present, we find that there are some bugs (presumably problems left over from an old version of pytorch-lightning) that prevent our code from running correctly on GPUs. For now, run the code on the CPU as shown above.
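In the run command above, {data_name} and {n_multiv} are placeholders, not literal arguments: substitute the dataset's base name and its number of variables (columns). A hedged example, assuming the electricity dataset from the LSTNet benchmark with 321 series (check your own file's column count):

# concrete example; dataset name and column count are assumptions
python single_cpu_trainer.py --data_name electricity --n_multiv 321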

Citation

If our code is helpful for your research, please cite our paper:

@inproceedings{Huang2019DSANet,
  author = {Siteng Huang and Donglin Wang and Xuehan Wu and Ao Tang},
  title = {DSANet: Dual Self-Attention Network for Multivariate Time Series Forecasting},
  booktitle = {Proceedings of the 28th ACM International Conference on Information and Knowledge Management (CIKM 2019)},
  month = {November},
  year = {2019},
  address = {Beijing, China}
}

Acknowledgement

Part of the code is heavily borrowed from jadore801120/attention-is-all-you-need-pytorch.


dsanet's Issues

Questions about splitting the dataset

Thanks for your code!

However, I have a question about splitting the dataset. Suppose there are 10,000 time steps of data (10,000 lines), and I want to split it 60%/20%/20% into train, validation, and test sets. There will be 6,000 lines in my training set, but I'm not sure whether the validation set should contain 2,000 lines or 6,000+2,000=8,000 lines.
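For reference, a minimal sketch of a plain chronological 60/20/20 split, assuming each split holds only its own slice (so the validation set is the 2,000 lines that follow the training data, not 8,000); some pipelines instead keep the preceding history in each split so that sliding windows can reach back across the boundary:

# chronological 60/20/20 split -- a sketch, not the repo's own loader
import numpy as np

data = np.loadtxt("electricity.txt", delimiter=",")  # hypothetical file with 10000 rows

n = data.shape[0]
train_end = int(n * 0.6)  # row 6000
valid_end = int(n * 0.8)  # row 8000

train = data[:train_end]           # rows 0..5999
valid = data[train_end:valid_end]  # 2000 rows (not 8000)
test  = data[valid_end:]           # final 2000 rows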

Should data be sorted in date-ascending or date-descending order?

I find that the example dataset from another repo has no header or index in the .txt files (electricity, exchange_rate, etc.). I am not sure whether the data is sorted in date-ascending (oldest first) or date-descending (newest first) order.

code

Your work is great, can you share your code?

Peaks/troughs Not Large Enough

I have had some promising results running the model on simple univariate time series data. I performed some hyperparameter tuning; however, my peaks and troughs are consistently too small (see below).

Are there any parameters I can tune to allow for more volatility? Any ideas are appreciated. Thank you!

[image: forecast plot with peaks and troughs smaller than the actual series]
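Not an official answer, but single_cpu_trainer.py exposes several flags (visible in the usage string quoted in a later issue) that plausibly influence how much volatility the model can express. A hedged starting point; the values and their claimed effects are assumptions to experiment with, not documented guidance:

# hypothetical tuning directions: a longer window gives the attention
# more history, lower dropout preserves sharper patterns, and more
# kernels/heads add capacity
python single_cpu_trainer.py --data_name my_series --n_multiv 1 \
    --window 64 --drop_prob 0.05 --n_kernels 64 --n_head 16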

saving trained model

Hi,

I would like to know how I can save and load the model. I used torch.save to save the model, but I received an error. Thank you!
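A minimal sketch of the usual PyTorch pattern, assuming the error came from pickling the whole LightningModule; saving only the state_dict generally avoids that:

import torch

# save only the weights rather than the full Lightning module
torch.save(model.state_dict(), "dsanet.pt")

# later: rebuild the model with the same hyperparameters, then load
# (DSANet(hparams) is a hypothetical constructor call -- match model.py)
model = DSANet(hparams)
model.load_state_dict(torch.load("dsanet.pt"))
model.eval()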

single_cpu_trainer.py: error: argument --n_multiv: invalid int value: '{n_multiv}'

Hello,
When I run the command "python single_cpu_trainer.py --data_name {data_name} --n_multiv {n_multiv}", I receive the following output:

usage: single_cpu_trainer.py [-h] [--test_tube_save_path TEST_TUBE_SAVE_PATH]
[--model_save_path MODEL_SAVE_PATH]
[--local LOCAL] [--n_kernels N_KERNELS]
[-w_kernel W_KERNEL] [--d_model D_MODEL]
[--d_inner D_INNER] [--d_k D_K] [--d_v D_V]
[--n_head N_HEAD] [--n_layers N_LAYERS]
[--drop_prob DROP_PROB] [--data_name DATA_NAME]
[--data_dir DATA_DIR] [--n_multiv N_MULTIV]
[--window WINDOW] [--horizon HORIZON]
[--learning_rate LEARNING_RATE]
[--optimizer_name OPTIMIZER_NAME]
[--criterion CRITERION] [--batch_size BATCH_SIZE]
single_cpu_trainer.py: error: argument --n_multiv: invalid int value: '{n_multiv}'

Then I tried running only "python single_cpu_trainer.py --data_name {data_name}", and my output is:

RUNNING ON CPU
loading model...
model built
gpu available: False, used: False
tng data loader called
test data loader called
val data loader called
Name
0 sgsf
1 sgsf.conv2
2 sgsf.in_linear
3 sgsf.out_linear
4 sgsf.layer_stack
5 sgsf.layer_stack.0
6 sgsf.layer_stack.0.slf_attn
7 sgsf.layer_stack.0.slf_attn.w_qs
8 sgsf.layer_stack.0.slf_attn.w_ks
9 sgsf.layer_stack.0.slf_attn.w_vs
10 sgsf.layer_stack.0.slf_attn.attention
11 sgsf.layer_stack.0.slf_attn.attention.dropout
12 sgsf.layer_stack.0.slf_attn.attention.softmax
13 sgsf.layer_stack.0.slf_attn.layer_norm
14 sgsf.layer_stack.0.slf_attn.fc
15 sgsf.layer_stack.0.slf_attn.dropout
16 sgsf.layer_stack.0.pos_ffn
17 sgsf.layer_stack.0.pos_ffn.w_1
18 sgsf.layer_stack.0.pos_ffn.w_2
19 sgsf.layer_stack.0.pos_ffn.layer_norm
20 sgsf.layer_stack.0.pos_ffn.dropout
21 sgsf.layer_stack.1
22 sgsf.layer_stack.1.slf_attn
23 sgsf.layer_stack.1.slf_attn.w_qs
24 sgsf.layer_stack.1.slf_attn.w_ks
25 sgsf.layer_stack.1.slf_attn.w_vs
26 sgsf.layer_stack.1.slf_attn.attention
27 sgsf.layer_stack.1.slf_attn.attention.dropout
28 sgsf.layer_stack.1.slf_attn.attention.softmax
29 sgsf.layer_stack.1.slf_attn.layer_norm
.. ...
178 slsf.layer_stack.4.slf_attn.attention.softmax
179 slsf.layer_stack.4.slf_attn.layer_norm
180 slsf.layer_stack.4.slf_attn.fc
181 slsf.layer_stack.4.slf_attn.dropout
182 slsf.layer_stack.4.pos_ffn
183 slsf.layer_stack.4.pos_ffn.w_1
184 slsf.layer_stack.4.pos_ffn.w_2
185 slsf.layer_stack.4.pos_ffn.layer_norm
186 slsf.layer_stack.4.pos_ffn.dropout
187 slsf.layer_stack.5
188 slsf.layer_stack.5.slf_attn
189 slsf.layer_stack.5.slf_attn.w_qs
190 slsf.layer_stack.5.slf_attn.w_ks
191 slsf.layer_stack.5.slf_attn.w_vs
192 slsf.layer_stack.5.slf_attn.attention
193 slsf.layer_stack.5.slf_attn.attention.dropout
194 slsf.layer_stack.5.slf_attn.attention.softmax
195 slsf.layer_stack.5.slf_attn.layer_norm
196 slsf.layer_stack.5.slf_attn.fc
197 slsf.layer_stack.5.slf_attn.dropout
198 slsf.layer_stack.5.pos_ffn
199 slsf.layer_stack.5.pos_ffn.w_1
200 slsf.layer_stack.5.pos_ffn.w_2
201 slsf.layer_stack.5.pos_ffn.layer_norm
202 slsf.layer_stack.5.pos_ffn.dropout
203 ar
204 ar.linear
205 W_output1
206 dropout
207 active_func

                          Type    Params

0 Single_Global_SelfAttn_Module 18949696
1 Conv2d 2080
2 Linear 16896
3 Linear 16416
4 ModuleList 18914304
5 EncoderLayer 3152384
6 MultiHeadAttention 1051648
7 Linear 262656
8 Linear 262656
9 Linear 262656
10 ScaledDotProductAttention 0
11 Dropout 0
12 Softmax 0
13 LayerNorm 1024
14 Linear 262656
15 Dropout 0
16 PositionwiseFeedForward 2100736
17 Conv1d 1050624
18 Conv1d 1049088
19 LayerNorm 1024
20 Dropout 0
21 EncoderLayer 3152384
22 MultiHeadAttention 1051648
23 Linear 262656
24 Linear 262656
25 Linear 262656
26 ScaledDotProductAttention 0
27 Dropout 0
28 Softmax 0
29 LayerNorm 1024
.. ... ...
178 Softmax 0
179 LayerNorm 1024
180 Linear 262656
181 Dropout 0
182 PositionwiseFeedForward 2100736
183 Conv1d 1050624
184 Conv1d 1049088
185 LayerNorm 1024
186 Dropout 0
187 EncoderLayer 3152384
188 MultiHeadAttention 1051648
189 Linear 262656
190 Linear 262656
191 Linear 262656
192 ScaledDotProductAttention 0
193 Dropout 0
194 Softmax 0
195 LayerNorm 1024
196 Linear 262656
197 Dropout 0
198 PositionwiseFeedForward 2100736
199 Conv1d 1050624
200 Conv1d 1049088
201 LayerNorm 1024
202 Dropout 0
203 AR 65
204 Linear 65
205 Linear 65
206 Dropout 0
207 Tanh 0

[208 rows x 3 columns]
Traceback (most recent call last):
  File "single_cpu_trainer.py", line 102, in <module>
    main(hyperparams)
  File "single_cpu_trainer.py", line 70, in main
    trainer.fit(model)
  File "C:\Users\abish\anaconda3\envs\Python36\lib\site-packages\pytorch_lightning\models\trainer.py", line 567, in fit
    self.__run_pretrain_routine(model)
  File "C:\Users\abish\anaconda3\envs\Python36\lib\site-packages\pytorch_lightning\models\trainer.py", line 778, in __run_pretrain_routine
    self.validate(model, dataloader, self.nb_sanity_val_steps, ds_i)
  File "C:\Users\abish\anaconda3\envs\Python36\lib\site-packages\pytorch_lightning\models\trainer.py", line 438, in validate
    output = self.__validation_forward(model, data_batch, batch_i, dataloader_i)
  File "C:\Users\abish\anaconda3\envs\Python36\lib\site-packages\pytorch_lightning\models\trainer.py", line 402, in __validation_forward
    output = model.validation_step(*args)
  File "C:\Program Files (x86)\DSAnet\DSANet\model.py", line 267, in validation_step
    y_hat = self.forward(x)
  File "C:\Program Files (x86)\DSAnet\DSANet\model.py", line 217, in forward
    sgsf_output, *_ = self.sgsf(x)
  File "C:\Users\abish\anaconda3\envs\Python36\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Program Files (x86)\DSAnet\DSANet\model.py", line 58, in forward
    x = x.view(-1, self.w_kernel, self.window, self.n_multiv)
TypeError: view(): argument 'size' must be tuple of ints, but found element of type NoneType at pos 4

My packages:

packages in environment at C:\Users\abish\anaconda3\envs\Python36:

Name Version Build Channel

absl-py 0.9.0 pypi_0 pypi
blas 1.0 mkl
cachetools 4.1.0 pypi_0 pypi
certifi 2020.4.5.1 py36h9f0ad1d_0 conda-forge
cffi 1.14.0 py36h7a1dbc1_0
chardet 3.0.4 pypi_0 pypi
cudatoolkit 10.0.130 0
freetype 2.9.1 ha9979f8_1
future 0.18.2 pypi_0 pypi
google-auth 1.13.1 pypi_0 pypi
google-auth-oauthlib 0.4.1 pypi_0 pypi
grpcio 1.28.1 pypi_0 pypi
icc_rt 2019.0.0 h0cc432a_1
idna 2.9 pypi_0 pypi
imageio 2.8.0 pypi_0 pypi
intel-openmp 2020.0 166
jpeg 9b hb83a4c4_2
libblas 3.8.0 15_mkl conda-forge
libcblas 3.8.0 15_mkl conda-forge
liblapack 3.8.0 15_mkl conda-forge
libpng 1.6.37 h2a8f88b_0
libtiff 4.1.0 h56a325e_0
markdown 3.2.1 pypi_0 pypi
mkl 2020.0 166
mkl-service 2.3.0 py36hb782905_0
mkl_fft 1.0.15 py36h14836fe_0
mkl_random 1.1.0 py36h675688f_0
ninja 1.9.0 py36h74a9793_0
numpy 1.16.4 py36hc71023c_0 conda-forge
oauthlib 3.1.0 pypi_0 pypi
olefile 0.46 py36_0
pandas 0.20.3 py36hce827b7_2 anaconda
pillow 7.0.0 py36hcc1f983_0
pip 20.0.2 py36_1
protobuf 3.11.3 pypi_0 pypi
pyasn1 0.4.8 pypi_0 pypi
pyasn1-modules 0.2.8 pypi_0 pypi
pycparser 2.20 py_0
python 3.6.10 h9f7ef89_1
python-dateutil 2.8.1 py_0 anaconda
python_abi 3.6 1_cp36m conda-forge
pytorch 1.2.0 py3.6_cuda100_cudnn7_1 pytorch
pytorch-lightning 0.4.6 pypi_0 pypi
pytz 2019.3 py_0 anaconda
requests 2.23.0 pypi_0 pypi
requests-oauthlib 1.3.0 pypi_0 pypi
rsa 4.0 pypi_0 pypi
scikit-learn 0.20.2 py36h343c172_0
scipy 1.4.1 py36h9439919_0
setuptools 46.1.3 py36_0
six 1.14.0 py36_0
sqlite 3.31.1 he774522_0
tb-nightly 1.15.0a20190708 pypi_0 pypi
tensorboard 2.2.0 pypi_0 pypi
tensorboard-plugin-wit 1.6.0.post3 pypi_0 pypi
test-tube 0.6.9 pypi_0 pypi
tk 8.6.8 hfa6e2cd_0
torchvision 0.3.0 py36_cu100_1 pytorch
tqdm 4.45.0 pypi_0 pypi
urllib3 1.25.8 pypi_0 pypi
vc 14.1 h0510ff6_4
vs2015_runtime 14.16.27012 hf0eaf9b_1
werkzeug 1.0.1 pypi_0 pypi
wheel 0.34.2 py36_0
wincertstore 0.2 py36h7fe50ca_0
xz 5.2.4 h2fa13f4_4
zlib 1.2.11 h62dcd97_3
zstd 1.3.7 h508b16e_0

Since I am new to programming and Python, I tried my best to find a solution to the bugs, but I was not able to figure it out. Any help would be appreciated. Thanks.
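A reading of the two runs above, not an official fix: {data_name} and {n_multiv} in the README are placeholders to substitute, so passing them literally makes argparse fail to parse an integer, and omitting --n_multiv entirely apparently leaves it at a default of None (judging by the traceback), which is the NoneType that view() rejects at position 4. A hedged invocation with concrete values:

# substitute the real dataset name and its column count
# (exchange_rate with 8 columns is an assumption, not a verified pairing)
python single_cpu_trainer.py --data_name exchange_rate --n_multiv 8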

pytorch_lightning version

In the requirements.txt file, the pytorch_lightning version is specified as >=0.4.6, so one (like me) may install a much newer version, such as the current 1.4.1. But the models module no longer exists as of 0.4.8. I would highly recommend pinning the pytorch_lightning version to avoid errors about non-existent modules.

No `train_dataloader()` method defined.

Thank you very much for releasing the code. I would like to cite your paper, but I ran into the following problem when running the code. What causes it? pytorch_lightning.utilities.exceptions.MisconfigurationException: No train_dataloader() method defined. Lightning Trainer expects as minimum a training_step(), train_dataloader() and configure_optimizers() to be defined. Thank you very much for your reply.
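Not an authoritative fix, but the training log quoted in an earlier issue ("tng data loader called") suggests the model was written against the old pytorch-lightning hook name tng_dataloader, while recent releases expect train_dataloader. Pinning pytorch-lightning to an old version, as recommended above, is the safer route; as a sketch, you could also alias the hook, assuming model.py defines tng_dataloader (other renamed APIs may still break afterwards):

# hedged compatibility shim: forward the new hook name to the old one
from model import DSANet  # hypothetical import path within this repo

DSANet.train_dataloader = DSANet.tng_dataloader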
