This repository contains PyTorch code replicating the paper Keyword Transformer: A Self-Attention Model for Keyword Spotting. Currently, KWT-1, KWT-2, and KWT-3 are supported without distillation, using 40x1 input patches from the spectrogram. Support for the distillation token (with MHAtt-RNN as the teacher) and for other input patch sizes is planned.
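As a rough illustration of the 40x1 patch setup (a sketch, not taken from this repository's code): each time step of the MFCC spectrogram, spanning all 40 frequency bins, becomes one token that is linearly projected into the transformer's embedding dimension. All module and dimension names below are assumptions:

```python
import torch
import torch.nn as nn

class PatchEmbed40x1(nn.Module):
    """Turn an MFCC spectrogram (batch, time, 40) into a token sequence.

    Each 40x1 patch (all 40 MFCC bins at one time step) is mapped to a
    d_model-dimensional token. Names and sizes here are illustrative.
    """
    def __init__(self, n_mfcc: int = 40, d_model: int = 192):
        super().__init__()
        self.proj = nn.Linear(n_mfcc, d_model)  # one linear map per time step

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, n_mfcc) -> (batch, time, d_model)
        return self.proj(x)

# Example: 98 time frames of 40 MFCC features -> 98 tokens of width 192
tokens = PatchEmbed40x1()(torch.randn(2, 98, 40))
print(tokens.shape)  # torch.Size([2, 98, 192])
```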
Replicated model performance on v1 and v2 datasets vs literature data:
| Model | V1-12 Replicated | V1-12 Paper | V2-12 Replicated | V2-12 Paper | V2-35 Replicated | V2-35 Paper | # Params Replicated | # Params Paper |
|---|---|---|---|---|---|---|---|---|
| KWT-3 | 95.94% | 97.24% | 97.40% | 98.54% | 95.72% | 97.51% | 5,361k | 5,361k |
| KWT-2 | 95.46% | 97.36% | 97.08% | 98.21% | 95.85% | 97.53% | 2,394k | 2,394k |
| KWT-1 | 95.03% | 97.05% | 95.99% | 97.72% | 94.75% | 96.85% | 557k | 607k |
Note on parameter counts: following the model details in the paper, KWT-2 and KWT-3 match the paper's parameter counts, but KWT-1 has fewer. The source of the discrepancy is still under investigation and may be related to the observed degradation in model performance.
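Parameter counts like those in the table can be verified with a one-line helper (not part of this repository), shown here on a stand-in model:

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    """Total number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Stand-in model: a single 40 -> 192 linear layer (40*192 weights + 192 biases)
demo = nn.Linear(40, 192)
print(count_params(demo))  # 7872
```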
Clone the repository:

```bash
git clone https://github.com/wdjose/keyword-transformer.git
cd keyword-transformer
```
Create the google-speech-commands folder, then download and extract the Google Speech Commands datasets:

```bash
mkdir -p data/google-speech-commands
cd data/google-speech-commands
mkdir -p data1 data2 data3
wget https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz
wget https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies 'https://docs.google.com/uc?export=download&id=1OAN3h4uffi5HS7eb7goklWeI2XPm1jCS' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1OAN3h4uffi5HS7eb7goklWeI2XPm1jCS" -O data_all_v2.zip && rm -rf /tmp/cookies.txt
tar -xzf speech_commands_v0.01.tar.gz -C data1
tar -xzf speech_commands_v0.02.tar.gz -C data2
unzip data_all_v2.zip -d data3
rm speech_commands_v0.01.tar.gz speech_commands_v0.02.tar.gz data_all_v2.zip
cd ../..
```
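After extraction, each of data1/data2/data3 should contain one subfolder per keyword filled with 1-second .wav clips. A quick sanity check can count clips per label; the helper below is hypothetical (not from this repository) and is demonstrated on a throwaway directory so it runs anywhere:

```python
import tempfile
from pathlib import Path

def wav_counts(dataset_dir: str) -> dict:
    """Map each label subfolder to its number of .wav files."""
    root = Path(dataset_dir)
    return {
        d.name: sum(1 for f in d.iterdir() if f.suffix == ".wav")
        for d in sorted(root.iterdir()) if d.is_dir()
    }

# Demonstration on a throwaway layout; point it at
# data/google-speech-commands/data1 (etc.) for a real check.
with tempfile.TemporaryDirectory() as tmp:
    for label, n in [("yes", 3), ("no", 2)]:
        d = Path(tmp) / label
        d.mkdir()
        for i in range(n):
            (d / f"{i}.wav").touch()
    print(wav_counts(tmp))  # {'no': 2, 'yes': 3}
```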
Export the kws_streaming subdirectory of the google-research repository:

```bash
svn export https://github.com/google-research/google-research/trunk/kws_streaming
```
Train the kwt1, kwt2, and kwt3 variants on the v1 and v2 datasets without distillation (this generates 12M augmented MFCC samples with TensorFlow for the v1 and v2 datasets):

```bash
bash train.sh
```
For the purposes of this repository, version=1 (data1) corresponds to v1-12 (12 labels), version=2 (data2) to v2-12 (12 labels), and version=3 (data3) to v2-35 (35 labels). Strictly speaking, version=3 refers to the modified v2 dataset with 35 labels as defined in the paper Streaming Keyword Spotting on Mobile Devices, from which the data augmentation code (the kws_streaming repository exported above) is taken.
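For reference, the 12-label setting consists of ten core keywords plus the special silence and unknown classes. The sketch below lists them following the Speech Commands convention used in the TensorFlow tooling; it is not read from this repository's code:

```python
# The standard 12-label Speech Commands setup: 10 keywords plus two
# special classes. Label names follow the TensorFlow speech_commands
# convention, not this repository's code.
CORE_KEYWORDS = ["yes", "no", "up", "down", "left", "right",
                 "on", "off", "stop", "go"]
LABELS_12 = ["_silence_", "_unknown_"] + CORE_KEYWORDS

print(len(LABELS_12))  # 12
```

The v2-35 setting instead keeps all 35 spoken words of the v2 dataset as individual labels, with no unknown class.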