

OpenSR: Open-Modality Speech Recognition via Maintaining Multi-Modality Alignment

Xize Cheng*, Tao Jin*, Linjun Li*, Wang Lin, Xinyu Duan, Zhou Zhao | Zhejiang University & Huawei Cloud

PyTorch implementation of OpenSR (ACL'23 Oral), an open-modality training framework that can be trained on a single modality and applied to multiple modalities.


If you find OpenSR useful in your research, please use the following BibTeX entry for citation.

@misc{cheng2023opensr,
      title={OpenSR: Open-Modality Speech Recognition via Maintaining Multi-Modality Alignment}, 
      author={Xize Cheng and Tao Jin and Linjun Li and Wang Lin and Xinyu Duan and Zhou Zhao},
      year={2023},
      eprint={2306.06410},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Installation

First, create a conda virtual environment and activate it:

conda create -n openSR python=3.8 -y
conda activate openSR

Then, clone this repository:

git clone https://github.com/Exgc/OpenSR.git
cd OpenSR

Lastly, install Fairseq and the other packages:

pip install -r requirements.txt
cd fairseq
pip install --editable ./

Open-Modality Speech Recognition (OpenSR)

1. Data preparation

Follow the steps in preparation to pre-process:

  • the LRS2 and LRS2-COMMON datasets

2. Audio-Visual Alignment Learning

Refer to the audio-visual speech pre-training model AV-HuBERT. The pre-trained checkpoints can be found here.

3. Decoder Training with Audio only

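Suppose {train,valid}.tsv are saved at /path/to/data, {train,valid}.wrd at /path/to/label, the configuration files under /path/to/conf/, the SentencePiece tokenizer model at /path/to/tokenizer, and the pre-trained AV-HuBERT checkpoint at /path/to/checkpoint. The trained model is saved in the experiment directory given by hydra.run.dir (here /path/to/experiment/opensr).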
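Before launching training, it can help to confirm that each manifest and its transcripts line up. The sketch below assumes the AV-HuBERT-style layout (the first line of each .tsv is a root directory, followed by one utterance per line; the .wrd file holds one transcript per utterance). `check_split` is a hypothetical helper, not part of OpenSR.

```python
from pathlib import Path


def check_split(data_dir: str, label_dir: str, split: str) -> int:
    """Return the number of utterances in `split` (hypothetical helper).

    Assumes <split>.tsv starts with a root-directory line followed by one
    utterance per line, and <split>.wrd holds one transcript per utterance.
    """
    tsv = Path(data_dir) / f"{split}.tsv"
    wrd = Path(label_dir) / f"{split}.wrd"
    for path in (tsv, wrd):
        if not path.is_file():
            raise FileNotFoundError(f"missing {path}")
    # Drop the root-directory header line of the manifest.
    rows = tsv.read_text().splitlines()[1:]
    labels = wrd.read_text().splitlines()
    if len(rows) != len(labels):
        raise ValueError(
            f"{split}: {len(rows)} manifest rows but {len(labels)} transcripts"
        )
    return len(rows)
```

For example, `check_split("/path/to/data", "/path/to/label", "train")` returns the number of training utterances, or raises on a missing file or a count mismatch.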

To train the decoder with audio only, we use the settings in opensr/opensr_large_vox_audio.yaml:

$ cd opensr
$ fairseq-hydra-train --config-dir /path/to/conf/ --config-name opensr/opensr_large_vox_audio.yaml \
  task.data=/path/to/data task.label_dir=/path/to/label \
  task.tokenizer_bpe_model=/path/to/tokenizer model.w2v_path=/path/to/checkpoint \
  hydra.run.dir=/path/to/experiment/opensr common.user_dir=`pwd`
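If runs are launched from a script rather than an interactive shell, the same command can be assembled as an argument list. This is a minimal sketch: `build_train_cmd` is a hypothetical helper, and the script is assumed to run from inside `opensr/`, where `os.getcwd()` plays the role of `` `pwd` ``.

```python
import os
import shlex
from typing import Dict, List


def build_train_cmd(config_dir: str, config_name: str,
                    overrides: Dict[str, str]) -> List[str]:
    """Assemble a fairseq-hydra-train argument list (hypothetical helper).

    Each key/value pair becomes one `key=value` Hydra override, in the
    order given, mirroring the shell command above.
    """
    cmd = ["fairseq-hydra-train", "--config-dir", config_dir,
           "--config-name", config_name]
    cmd += [f"{key}={value}" for key, value in overrides.items()]
    return cmd


overrides = {
    "task.data": "/path/to/data",
    "task.label_dir": "/path/to/label",
    "task.tokenizer_bpe_model": "/path/to/tokenizer",
    "model.w2v_path": "/path/to/checkpoint",
    "hydra.run.dir": "/path/to/experiment/opensr",
    # Assumes this script runs from inside opensr/, like `pwd` in the README.
    "common.user_dir": os.getcwd(),
}
cmd = build_train_cmd("/path/to/conf/", "opensr/opensr_large_vox_audio.yaml",
                      overrides)
print(shlex.join(cmd))  # shell-quoted form, handy for logging
```

Passing the list to `subprocess.run(cmd, check=True)` bypasses the shell, so overrides containing brackets or quotes (such as `override.modalities=['video']` at inference time) reach Hydra unchanged.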

4. Tuning of the Target-domain Decoder

Full-Shot

To further tune the model on the entire set of visual utterances, we use the settings in opensr/large_vox_video.yaml or opensr/large_vox_audio_video.yaml:

$ cd opensr
$ mkdir -p /path/to/experiment/full-shot
$ cp /path/to/experiment/opensr/checkpoint_best.pt /path/to/experiment/full-shot/checkpoint_last.pt
$ fairseq-hydra-train --config-dir /path/to/conf/ --config-name opensr/large_vox_video.yaml \
  task.data=/path/to/data task.label_dir=/path/to/label \
  task.tokenizer_bpe_model=/path/to/tokenizer model.w2v_path=/path/to/checkpoint \
  hydra.run.dir=/path/to/experiment/full-shot common.user_dir=`pwd`
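The copy step initializes the new run from the audio-trained model: fairseq restores `checkpoint_last.pt` from the save directory when that file exists. A minimal Python equivalent, assuming the save directory is the experiment directory as the `cp` command implies (`init_from_checkpoint` is a hypothetical helper):

```python
import shutil
from pathlib import Path


def init_from_checkpoint(src_exp: str, dst_exp: str) -> Path:
    """Copy <src_exp>/checkpoint_best.pt to <dst_exp>/checkpoint_last.pt.

    Assumes the new run's save directory is dst_exp, so fairseq picks the
    copy up as the checkpoint to resume from. An existing
    checkpoint_last.pt is never overwritten.
    """
    src = Path(src_exp) / "checkpoint_best.pt"
    dst = Path(dst_exp) / "checkpoint_last.pt"
    if not src.is_file():
        raise FileNotFoundError(f"missing {src}")
    if dst.exists():
        raise FileExistsError(f"{dst} already exists; refusing to overwrite")
    dst.parent.mkdir(parents=True, exist_ok=True)  # same as `mkdir -p`
    shutil.copy2(src, dst)
    return dst
```

Refusing to overwrite guards against silently discarding progress when a tuning run is restarted in the same directory.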

Few-Shot

To tune the model with visual speech of common words only, we use the settings in one of prompt/large_vox_base_{10,20,50,100}.yaml. The number of clustering centers is set by prompt_strategy: base_{number of centers} in the YAML file.

$ cd opensr
$ mkdir -p /path/to/experiment/few-shot
$ cp /path/to/experiment/opensr/checkpoint_best.pt /path/to/experiment/few-shot/checkpoint_last.pt
$ N=10  # number of clustering centers: 10, 20, 50 or 100
$ fairseq-hydra-train --config-dir /path/to/conf/ --config-name prompt/large_vox_base_${N}.yaml \
  task.data=/path/to/data task.label_dir=/path/to/label \
  task.tokenizer_bpe_model=/path/to/tokenizer model.w2v_path=/path/to/checkpoint \
  hydra.run.dir=/path/to/experiment/few-shot common.user_dir=`pwd`
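The `{10,20,50,100}` in the config name stands for a single choice; typed literally in a shell it would brace-expand into four arguments. A small sketch of the naming convention, assuming the file name and the `prompt_strategy: base_{N}` value use the same N (`few_shot_config` and `centers_from_strategy` are hypothetical helpers):

```python
import re

# Center counts for which a prompt/large_vox_base_{N}.yaml config is listed.
SUPPORTED_CENTERS = (10, 20, 50, 100)


def few_shot_config(num_centers: int) -> str:
    """Return the config name for a given number of clustering centers."""
    if num_centers not in SUPPORTED_CENTERS:
        raise ValueError(f"no listed config for {num_centers} centers")
    return f"prompt/large_vox_base_{num_centers}.yaml"


def centers_from_strategy(prompt_strategy: str) -> int:
    """Parse the number of centers from a value such as 'base_50'."""
    match = re.fullmatch(r"base_(\d+)", prompt_strategy)
    if match is None:
        raise ValueError(f"unexpected prompt_strategy: {prompt_strategy!r}")
    return int(match.group(1))
```

Checking that `few_shot_config(centers_from_strategy(value))` matches the chosen file name is a cheap way to catch a YAML whose prompt_strategy was edited without renaming the file.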

Zero-Shot & Inference

Suppose test.tsv and test.wrd, the video list and transcripts of the split to be decoded, are saved at /path/to/data. task.normalize must match the value used during fine-tuning. Decoding results will be saved at /path/to/experiment/decode/s2s/test.

  • For Zero-Shot, we run inference directly with the model saved in /path/to/experiment/opensr.
  • For Full-Shot and Few-Shot, we run inference with the model saved in /path/to/experiment/{full-shot,few-shot}.

$ cd opensr
$ EXP=opensr  # opensr (Zero-Shot), few-shot or full-shot
$ python -B infer_s2s.py --config-dir ./conf/ --config-name conf-name \
  dataset.gen_subset=test common_eval.path=/path/to/experiment/${EXP}/checkpoint_best.pt \
  common_eval.results_path=/path/to/experiment/decode/s2s/test \
  override.modalities=['video'] common.user_dir=`pwd`

The command above uses the default decoding hyperparameters, which can be found in conf/s2s_decode.yaml. override.modalities can be set to ['video'] (lip reading), ['audio'] (ASR), or ['audio','video'] (audio-visual speech recognition). These parameters can be configured from the command line. For example, to search with a beam size of 20, append generation.beam=20 to the command above. Important parameters include:

  • generation.beam
  • generation.lenpen
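The overrides that change between decoding runs (evaluation setting, input modality, beam size, length penalty) can be collected in one place. A sketch under stated assumptions: `decode_overrides` is a hypothetical helper, and the checkpoint file name `checkpoint_best.pt` inside each experiment directory is assumed rather than taken from the OpenSR code.

```python
from typing import List, Optional

# Experiment directory for each evaluation setting, as listed above.
EXPERIMENT_DIRS = {
    "zero-shot": "/path/to/experiment/opensr",
    "few-shot": "/path/to/experiment/few-shot",
    "full-shot": "/path/to/experiment/full-shot",
}

# override.modalities value for each recognition task.
MODALITIES = {
    "lip reading": "['video']",
    "asr": "['audio']",
    "avsr": "['audio','video']",
}


def decode_overrides(setting: str, task: str, beam: Optional[int] = None,
                     lenpen: Optional[float] = None) -> List[str]:
    """Build Hydra overrides for infer_s2s.py (hypothetical helper)."""
    overrides = [
        "dataset.gen_subset=test",
        # Assumption: the file to decode is checkpoint_best.pt in that directory.
        f"common_eval.path={EXPERIMENT_DIRS[setting]}/checkpoint_best.pt",
        "common_eval.results_path=/path/to/experiment/decode/s2s/test",
        f"override.modalities={MODALITIES[task]}",
    ]
    # Only emit generation.* overrides when set, so conf/s2s_decode.yaml
    # defaults apply otherwise.
    if beam is not None:
        overrides.append(f"generation.beam={beam}")
    if lenpen is not None:
        overrides.append(f"generation.lenpen={lenpen}")
    return overrides
```

For instance, `decode_overrides("zero-shot", "lip reading", beam=20)` yields the overrides of the command above plus `generation.beam=20`.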


opensr's Issues

preval.txt

I noticed the following step in the data preprocessing process: preval_fids = set([f'short-pretrain/{x.strip()}' for x in open(os.path.join(args.lrs2, 'preval.txt')).readlines()]). Could you please clarify what 'preval.txt' refers to? I couldn't find this file in my LRS2 dataset. Is it all the IDs from the previously processed 'short-pretrain
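The quoted line reads one utterance ID per line from `preval.txt` and prefixes each with `short-pretrain/`. A more explicit version of that step, as a sketch (`load_preval_ids` is a hypothetical name): unlike the one-liner, it closes the file, skips blank lines, and reports a missing `preval.txt` directly. It does not answer where the file comes from.

```python
import os
from typing import Set


def load_preval_ids(lrs2_root: str) -> Set[str]:
    """Read <lrs2_root>/preval.txt and prefix each ID with 'short-pretrain/'."""
    path = os.path.join(lrs2_root, "preval.txt")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"{path} not found; this preprocessing step expects it")
    with open(path) as f:
        # Skip blank lines so a stray empty line does not yield 'short-pretrain/'.
        return {f"short-pretrain/{line.strip()}" for line in f if line.strip()}
```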

ImportError: cannot import name 'metrics' from 'fairseq' (unknown location)

When I am training the decoder with audio, I encounter the following issue:

[2023-11-29 22:38:41,125][opensr.hubert_dataset][INFO] - pad_audio=True, random_crop=False, normalize=True, max_sample_size=500, seqs2seq data=True,
[2023-11-29 22:38:41,125][opensr.hubert_dataset][INFO] - Noise wav: /usr/zzs/data/mvlrs_v1/noise/babble/train.tsv->1 wav, Prob: 0.25, SNR: 0, Number of mixture: 1
[2023-11-29 22:38:42,169][fairseq.logging.progress_bar][WARNING] - tensorboard not found, please install with: pip install tensorboard
[2023-11-29 22:38:42,170][fairseq.trainer][INFO] - begin training epoch 1
[2023-11-29 22:38:42,171][fairseq_cli.train][INFO] - Start iterating over samples
0%| | 0/201 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/multiprocessing/spawn.py", line 125, in _main
    prepare(preparation_data)
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/multiprocessing/spawn.py", line 236, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/multiprocessing/spawn.py", line 287, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/runpy.py", line 265, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/aa/anaconda3/envs/openSR/bin/fairseq-hydra-train", line 5, in <module>
    from fairseq_cli.hydra_train import cli_main
  File "/home/aa/OpenSR/fairseq/fairseq_cli/hydra_train.py", line 11, in <module>
    from fairseq_cli.train import main as pre_main
  File "/home/aa/OpenSR/fairseq/fairseq_cli/train.py", line 30, in <module>
    from fairseq import (
  File "/home/aa/OpenSR/fairseq/fairseq/checkpoint_utils.py", line 26, in <module>
    from fairseq.models import FairseqDecoder, FairseqEncoder
  File "/home/aa/OpenSR/fairseq/fairseq/models/__init__.py", line 225, in <module>
    import_models(models_dir, "fairseq.models")
  File "/home/aa/OpenSR/fairseq/fairseq/models/__init__.py", line 207, in import_models
    importlib.import_module(namespace + "." + model_name)
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home/aa/OpenSR/fairseq/fairseq/models/wav2vec/__init__.py", line 6, in <module>
    from .wav2vec import * # noqa
  File "/home/aa/OpenSR/fairseq/fairseq/models/wav2vec/wav2vec.py", line 25, in <module>
    from fairseq.tasks import FairseqTask
  File "/home/aa/OpenSR/fairseq/fairseq/tasks/__init__.py", line 15, in <module>
    from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa
  File "/home/aa/OpenSR/fairseq/fairseq/tasks/fairseq_task.py", line 13, in <module>
    from fairseq import metrics, search, tokenizer, utils
ImportError: cannot import name 'metrics' from 'fairseq' (unknown location)
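In Python, `(unknown location)` in this ImportError means the imported `fairseq` module has no `__file__`, which is typically the case when a directory without `__init__.py` was picked up as a namespace package instead of the editable install. One possible check, not a confirmed diagnosis of this issue (`package_location` is a hypothetical helper):

```python
import importlib
from typing import Optional


def package_location(name: str) -> Optional[str]:
    """Return the file a package was imported from, or None.

    None means `name` resolved to a namespace package (a directory
    without __init__.py). Such a module has no __file__, which is what
    makes ImportError report "(unknown location)".
    """
    module = importlib.import_module(name)
    return getattr(module, "__file__", None)
```

Calling `package_location("fairseq")` in the same environment and working directory as the failing run should return a path ending in `fairseq/fairseq/__init__.py` if the editable install is the one being imported; `None` points to a namespace-package shadow.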

Decoder Training with Audio only

When I train the decoder with audio only, I encounter an AssertionError: Could not infer the task type from the following configuration {'_name': 'av_hubert_pretraining', 'is_s2s': True, 'data': '/usr/zzs/data/mvlrs_v1/29h_data', 'label_dir': '/usr/zzs/data/mvlrs_v1/29h_data', 'tokenizer_bpe_model': '/usr/zzs/data/mvlrs_v1/spm1000/spm_unigram1000.model', 'normalize': True, 'labels': ['wrd'], 'single_target': True, 'fine_tuning': True, 'stack_order_audio': 4, 'tokenizer_bpe_name': 'sentencepiece', 'max_sample_size': 500, 'modalities': ['audio'], 'image_aug': True, 'pad_audio': True, 'random_crop': False, 'noise_prob': 0.25, 'noise_snr': 0, 'noise_wav': '???'}.

Available argparse tasks: dict_keys(['translation', 'translation_lev', 'speech_to_text', 'hubert_pretraining', 'speech_unit_modeling', 'multilingual_translation', 'multilingual_masked_lm', 'text_to_speech', 'frm_text_to_speech', 'denoising', 'multilingual_denoising', 'legacy_masked_lm', 'semisupervised_translation', 'translation_multi_simple_epoch', 'simul_speech_to_text', 'simul_text_to_text', 'sentence_prediction', 'sentence_prediction_adapters', 'cross_lingual_lm', 'translation_from_pretrained_bart', 'sentence_ranking', 'speech_to_speech', 'online_backtranslation', 'audio_pretraining', 'audio_finetuning', 'multilingual_language_modeling', 'language_modeling', 'masked_lm', 'translation_from_pretrained_xlm', 'dummy_lm', 'dummy_masked_lm', 'dummy_mt']).

Available hydra tasks: dict_keys(['translation', 'translation_lev', 'hubert_pretraining', 'speech_unit_modeling', 'simul_text_to_text', 'sentence_prediction', 'sentence_prediction_adapters', 'audio_pretraining', 'audio_finetuning', 'multilingual_language_modeling', 'language_modeling', 'masked_lm', 'translation_from_pretrained_xlm', 'dummy_lm', 'dummy_masked_lm']).
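fairseq looks tasks up by name in a registry that is filled by the `@register_task(...)` decorator when the module defining each task is imported; task modules outside fairseq are imported through `common.user_dir`. Neither list above contains `av_hubert_pretraining`, which is consistent with the OpenSR module not having been imported (for example, `common.user_dir` unset or pointing elsewhere), though that is an inference, not a confirmed cause. The simplified stand-in below is not fairseq's actual code; it only illustrates the mechanism:

```python
from typing import Callable, Dict, Type

# Simplified stand-in for a name -> class task registry (not fairseq's code).
TASK_REGISTRY: Dict[str, Type] = {}


def register_task(name: str) -> Callable[[Type], Type]:
    """Class decorator that records `cls` under `name` at import time."""
    def wrapper(cls: Type) -> Type:
        if name in TASK_REGISTRY:
            raise ValueError(f"duplicate task {name!r}")
        TASK_REGISTRY[name] = cls
        return cls
    return wrapper


def infer_task(cfg: Dict[str, object]) -> Type:
    """Look up the task class named by cfg['_name']."""
    name = cfg.get("_name")
    if name not in TASK_REGISTRY:
        # Mirrors the AssertionError above: the name was never registered.
        raise AssertionError(
            f"Could not infer task type from {cfg}; "
            f"available tasks: {sorted(TASK_REGISTRY)}"
        )
    return TASK_REGISTRY[name]
```

Until a module containing `@register_task("av_hubert_pretraining")` is imported, `infer_task({"_name": "av_hubert_pretraining"})` fails exactly as in the report; after the import, the lookup succeeds.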

mean faces

In the process of preprocessing LRS2, this script generates mouth Regions of Interest (ROIs) in ${lrs2}/video. It shards all the videos listed in ${lrs2}/file.list into ${nshard} portions and generates mouth ROIs for the ${rank}-th shard, where "rank" is an integer in the range [0, nshard-1]. Face detection and landmark prediction are performed using the Dlib library.

The links to download the cnn_detector, face_detector, and mean_face can be found in the help message. However, it's not clear where to download the "mean_face," and further guidance is needed.
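The sharding described above can be pictured as splitting the list from `file.list` into `nshard` contiguous slices and processing only slice `rank`. A sketch assuming ceil-sized contiguous slices; the actual preprocessing script may split differently, and `shard` is a hypothetical helper:

```python
import math
from typing import List, Sequence


def shard(items: Sequence[str], nshard: int, rank: int) -> List[str]:
    """Return the rank-th of nshard contiguous, nearly equal slices of items."""
    if not 0 <= rank < nshard:
        raise ValueError(f"rank must be in [0, {nshard - 1}], got {rank}")
    per_shard = math.ceil(len(items) / nshard)
    start = per_shard * rank
    return list(items[start:start + per_shard])
```

Running ranks 0 through nshard-1 (for example, as parallel jobs) covers every listed video exactly once; only the last shard may be shorter.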

AttributeError: 'AVHubertSeq2Seq' object has no attribute 'prompt_init'

Hello, when I fine-tune the pre-trained model on audio and then attempt few-shot fine-tuning with visual input, I encounter the following error. Could a definition be missing from this class in the code?
Traceback (most recent call last):
  File "/home/aa/anaconda3/envs/openSR/bin/fairseq-hydra-train", line 8, in <module>
    sys.exit(cli_main())
  File "/home/aa/OpenSR/fairseq/fairseq_cli/hydra_train.py", line 76, in cli_main
    hydra_main()
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/hydra/main.py", line 32, in decorated_main
    _run_hydra(
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/hydra/_internal/utils.py", line 346, in _run_hydra
    run_and_report(
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/hydra/_internal/utils.py", line 201, in run_and_report
    raise ex
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/hydra/_internal/utils.py", line 198, in run_and_report
    return func()
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/hydra/_internal/utils.py", line 347, in <lambda>
    lambda: hydra.run(
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 107, in run
    return run_job(
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/hydra/core/utils.py", line 129, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/aa/OpenSR/fairseq/fairseq_cli/hydra_train.py", line 45, in hydra_main
    distributed_utils.call_main(cfg, pre_main)
  File "/home/aa/OpenSR/fairseq/fairseq/distributed/utils.py", line 369, in call_main
    main(cfg, **kwargs)
  File "/home/aa/OpenSR/fairseq/fairseq_cli/train.py", line 97, in main
    model = task.build_model(cfg.model)
  File "/home/aa/OpenSR/fairseq/fairseq/tasks/fairseq_task.py", line 325, in build_model
    model = models.build_model(cfg, self)
  File "/home/aa/OpenSR/fairseq/fairseq/models/__init__.py", line 96, in build_model
    return model.build_model(cfg, task)
  File "/home/aa/OpenSR/opensr/hubert_asr.py", line 489, in build_model
    encoder = HubertEncoderWrapper(encoder, cfg.prompting, cfg.prompt_strategy)
  File "/home/aa/OpenSR/opensr/hubert_asr.py", line 391, in __init__
    self.w2v_model.prompt_init(strategy=strategy)
  File "/home/aa/anaconda3/envs/openSR/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'AVHubertSeq2Seq' object has no attribute 'prompt_init'
