
alexanderzhanxiake / tts-cgan


This project forked from imics-lab/tts-cgan


TTS-CGAN: A Transformer Time-Series Conditional GAN for Biosignal Data Augmentation

License: Apache License 2.0




This repository contains code from the paper "TTS-CGAN: A Transformer Time-Series Conditional GAN for Biosignal Data Augmentation".


Abstract: Signal measurement appearing in the form of time series is one of the most common types of data used in pervasive computing applications. Such datasets are often small in size, expensive to collect and annotate, and might involve privacy issues, which hinder our ability to train large, state-of-the-art deep learning models. For time-series data, the suite of data augmentation strategies we can use to expand the size of the dataset is limited by the need to maintain the basic properties of the signal. Generative Adversarial Networks (GANs) can be utilized as another data augmentation tool. In this paper, we present TTS-CGAN, a transformer-based conditional GAN model that can be trained on existing multi-class datasets and generate class-specific synthetic time-series sequences of arbitrary length. We elaborate on the model architecture and design strategies. Synthetic sequences generated by our models are almost indistinguishable from real ones and can be used to complement or replace real signals of the same type, thus achieving the goal of data augmentation. To evaluate the quality of the generated data, we modify the wavelet coherence metric to compare the similarity between real and synthetic signals. Finally, qualitative visualizations using t-SNE and quantitative comparisons using the discriminative and predictive power of the synthetic data show that TTS-CGAN outperforms other state-of-the-art GAN models built for time-series data generation.

Major Contributions:

Use Transformer GAN to generate multi-category synthetic time-series data.

Use Wavelet Coherence score to compare the similarity between two sets of signals.
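For reference, the standard wavelet coherence between two signals $X$ and $Y$, as commonly defined in the wavelet-analysis literature, is given below. The paper modifies this metric to score the similarity between two sets of signals; that modification is not reproduced here.

```latex
R^2_n(s) = \frac{\left| \mathcal{S}\!\left( s^{-1} W^{XY}_n(s) \right) \right|^2}
                {\mathcal{S}\!\left( s^{-1} \left| W^{X}_n(s) \right|^2 \right)\,
                 \mathcal{S}\!\left( s^{-1} \left| W^{Y}_n(s) \right|^2 \right)},
\qquad
W^{XY}_n(s) = W^{X}_n(s)\,\overline{W^{Y}_n(s)}
```

Here $W^X_n(s)$ is the continuous wavelet transform of $X$ at time index $n$ and scale $s$, $W^{XY}_n(s)$ is the cross-wavelet transform, and $\mathcal{S}$ is a smoothing operator in time and scale. By the Cauchy-Schwarz inequality $R^2_n(s) \in [0, 1]$, with values near 1 where the two signals co-vary at that time and scale.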

The TTS-CGAN Architecture

[Figure: TTS-CGAN Architecture]


Code structure:

TransCGN_model.py: The TTS-CGAN model architecture. It contains the code for the transformer-based generator and discriminator. The generator is conditioned on class labels through a label embedding. The discriminator has two classification heads: one for adversarial (real/fake) classification and one for categorical (class) classification.
trainCGAN.py: Contains code for model initialization, dataset loading, and the training process. Several intermediate results are shown in TensorBoard.

Dataloader.py: The PyTorch dataloader for the MIT-BIH heartbeat signals. Download the dataset files mitbih_train.csv and mitbih_test.csv from here and save them to your code directory.

synDataloader.py: The PyTorch dataloader for synthetic MIT-BIH signals produced by the pre-trained generator.

mitbih_checkpoint: A pretrained TTS-CGAN checkpoint.

cgan_functions.py, utils.py: Contain helper functions.

adamw.py: An AdamW optimizer implementation.

cfg.py: Definitions of the parser (command-line) arguments.

mitbih_Train_CGAN.py: A script for starting model training.

classification.ipynb: Post-hoc classification examples used to generate the plots in Figure 10 of the paper.

LookAtData.ipynb: Shows plots of real MIT-BIH heartbeat signals and synthetic signals.

Label-embedding (folder): Contains code used to generate the plots in Figure 5 of the paper.
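A minimal, standard-library sketch of reading the MIT-BIH CSV files into signals and class labels. It assumes each row holds one heartbeat's samples followed by its class label in the last column, which matches the widely used Kaggle release of mitbih_train.csv and mitbih_test.csv; check this against your copy. The helper name `load_mitbih_rows` is hypothetical, and this is not the repository's Dataloader.py, which wraps the data in a PyTorch dataloader.

```python
import csv
import io
from collections import Counter
from typing import Iterable, List, Tuple


def load_mitbih_rows(lines: Iterable[str]) -> Tuple[List[List[float]], List[int]]:
    """Split each CSV row into a heartbeat signal and an integer class label.

    Assumption (hypothetical helper): every column except the last is a signal
    sample, and the last column is the class label stored as a float string.
    """
    signals: List[List[float]] = []
    labels: List[int] = []
    for row in csv.reader(lines):
        if not row:
            continue  # tolerate blank lines, e.g. a trailing newline
        values = [float(v) for v in row]
        signals.append(values[:-1])     # heartbeat samples
        labels.append(int(values[-1]))  # e.g. "2.000000000000000000e+00" -> 2
    return signals, labels


# Tiny in-memory example standing in for the real CSV file.
demo = io.StringIO("0.1,0.5,0.2,0.0\n0.3,0.9,0.4,2.0\n0.2,0.8,0.1,2.0\n")
signals, labels = load_mitbih_rows(demo)
print(labels)          # [0, 2, 2]
print(Counter(labels)) # Counter({2: 2, 0: 1})
```

To read the real file, pass an open handle, e.g. `with open("mitbih_train.csv", newline="") as f: signals, labels = load_mitbih_rows(f)`. Counting labels per class is a quick check of class balance before training a conditional model.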
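A discriminator with an adversarial head and a categorical head is often trained with a combined objective in the style of AC-GAN: a real/fake loss from the adversarial head plus a cross-entropy loss from the class head. The pure-Python sketch below computes that combination for one sample. Binary cross-entropy for the adversarial term and the weight `cls_weight` are illustrative assumptions; the repository's training code may use a different adversarial loss (for example, least squares) or weighting, and all function names here are hypothetical.

```python
import math
from typing import List


def bce_with_logit(logit: float, target: float) -> float:
    """Binary cross-entropy on one raw logit, in the numerically stable form."""
    # Equals -[t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def cross_entropy(logits: List[float], label: int) -> float:
    """Softmax cross-entropy for one sample, computed via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def discriminator_loss(adv_logit: float, cls_logits: List[float], label: int,
                       is_real: bool, cls_weight: float = 1.0) -> float:
    """Assumed AC-GAN-style sum of the adversarial-head and class-head losses."""
    adv = bce_with_logit(adv_logit, 1.0 if is_real else 0.0)  # real -> 1, fake -> 0
    cls = cross_entropy(cls_logits, label)                      # predict the class label
    return adv + cls_weight * cls


# An undecided discriminator (all logits 0) on a real sample of class 0, 2 classes:
print(round(discriminator_loss(0.0, [0.0, 0.0], label=0, is_real=True), 4))  # 1.3863
```

The printed value is log 2 + log 2 = log 4. In this style of training, the generator update typically reuses the class term on its fake samples, pushing each one to be classified as the label it was conditioned on.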

Implementation instructions:

To re-train the model:

python mitbih_Train_CGAN.py

Modify the parser arguments (defined in cfg.py) to fit your dataset.
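Adapting a run to another dataset mostly means changing argument values. The sketch below shows the pattern with Python's standard `argparse` module. The flag names and defaults are hypothetical, chosen only for illustration; check cfg.py for the options mitbih_Train_CGAN.py actually accepts.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical training options; cfg.py defines the real ones."""
    parser = argparse.ArgumentParser(description="Sketch of TTS-CGAN-style training options")
    parser.add_argument("--seq_len", type=int, default=187,
                        help="samples per sequence (illustrative default)")
    parser.add_argument("--num_classes", type=int, default=5,
                        help="number of class labels (illustrative default)")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_epoch", type=int, default=100)
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # argv=None reads sys.argv; passing an explicit list is handy for testing.
    return build_parser().parse_args(argv)


# Example: adapt the defaults to a hypothetical 3-class dataset of 250-sample sequences.
args = parse(["--num_classes", "3", "--seq_len", "250"])
print(args.num_classes, args.seq_len, args.batch_size)  # 3 250 32
```

From the shell, the same overrides would be passed as flags after the script name; unspecified options keep their defaults.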

