

SoundCTM: Uniting Score-based and Consistency Models for Text-to-Sound Generation

This repository is the official implementation of "SoundCTM: Uniting Score-based and Consistency Models for Text-to-Sound Generation".

Contact:

Checkpoints

Inference uses two sets of checkpoints: AudioLDM-s-full (for the VAE decoder and the vocoder) and SoundCTM.

Prerequisites

Install Docker on your server and build the Docker image:

docker build -t soundctm .

Then run the scripts described below inside a container started from this image.
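Starting the container is not spelled out here. The helper below is a minimal sketch that assembles the same `docker run` invocation used in the issue report on this page: all GPUs exposed (this assumes the NVIDIA Container Toolkit is installed), the container removed on exit, and the repository mounted at an identical path so relative paths in the shell scripts keep working. The function name `docker_run_command` is hypothetical.

```python
import os
import shlex
from typing import List, Optional


def docker_run_command(image: str = "soundctm", workdir: Optional[str] = None) -> List[str]:
    """Build a `docker run` argv that mounts the repo and opens a shell in it."""
    workdir = os.path.abspath(workdir or os.getcwd())
    return [
        "docker", "run",
        "--gpus", "all",               # expose all GPUs (needs NVIDIA Container Toolkit)
        "--rm", "-it",                 # remove container on exit, interactive TTY
        "-v", f"{workdir}:{workdir}",  # mount the repo at the same path inside
        "-w", workdir,                 # start the shell in the repo directory
        image,
        "/bin/bash",
    ]


if __name__ == "__main__":
    # Print a copy-pasteable command; pass the list to subprocess.run() to launch it.
    print(shlex.join(docker_run_command()))
```

Mounting at the same path (rather than at a fixed `/workspace`) is a design choice: paths you edit into the `.sh` files on the host stay valid inside the container.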

Training

Please see ctm_train.sh and ctm_train.py, and modify the folder paths depending on your environment.

Then run bash ctm_train.sh

Inference

Please see ctm_inference.sh and ctm_inference.py, and modify the folder paths depending on your environment.

Then run bash ctm_inference.sh
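Because every checkpoint and config path has to be edited by hand, a pre-flight check can catch typos before the model loads. The sketch below uses the flag names that appear in the example ctm_inference.sh quoted in the issue further down this page; the helper names `missing_paths` and `build_inference_argv` are hypothetical, and any sampling flags are simply passed through rather than validated.

```python
import os
import shlex
from typing import Dict, List, Sequence

# Path-valued flags of ctm_inference.py, as used in the example
# ctm_inference.sh quoted in the issue below.
PATH_FLAGS = (
    "ctm_unet_model_config",
    "training_args",
    "model",
    "ema_model",
    "test_file",
)


def missing_paths(paths: Dict[str, str]) -> List[str]:
    """Return the flag names whose file is not present on disk."""
    return [name for name in PATH_FLAGS if not os.path.isfile(paths[name])]


def build_inference_argv(paths: Dict[str, str], output_dir: str = "outputs/",
                         extra: Sequence[str] = ()) -> List[str]:
    """Assemble the ctm_inference.py command line from the given paths."""
    argv = ["python", "ctm_inference.py",
            "--text_encoder_name", "google/flan-t5-large"]
    for name in PATH_FLAGS:
        argv += [f"--{name}", paths[name]]
    # Sampling flags (--sampler, --omega, --num_steps, ...) are passed through as-is.
    argv += list(extra)
    argv += ["--output_dir", output_dir]
    return argv


if __name__ == "__main__":
    paths = {
        "ctm_unet_model_config": "configs/diffusion_model_config.json",
        "training_args": "ckpt/030000/summary.jsonl",
        "model": "ckpt/030000/pytorch_model.bin",
        "ema_model": "ckpt/030000/ema_0.999_030000.pt",
        "test_file": "data/test.csv",
    }
    missing = missing_paths(paths)
    if missing:
        print("Edit these paths first:", ", ".join(missing))
    else:
        print(shlex.join(build_inference_argv(paths, extra=["--num_steps", "1"])))
```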

Numerical evaluation

Please see numerical_evaluation.sh and numerical_evaluation.py, and modify the folder paths depending on your environment.

Then run bash numerical_evaluation.sh

Dataset

Follow the instructions in the AudioCaps repository to download the data. The data locations must be specified in ctm_train.sh. Example entries are provided in data/train.csv.
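Since AudioCaps is downloaded clip by clip, some audio files are often unavailable, so it can help to verify the CSV against the disk before training. The column names of data/train.csv are not documented here: the defaults `location` (audio path) and `captions` (text) below are assumptions borrowed from the Tango codebase's conventions, so pass the names from your own CSV header. The function `check_audio_csv` is hypothetical.

```python
import csv
import os
from typing import List, Tuple


def check_audio_csv(csv_path: str, audio_column: str = "location",
                    text_column: str = "captions") -> Tuple[int, List[str]]:
    """Count rows in a caption CSV and list audio paths that do not exist.

    The default column names are assumptions; use the ones in your CSV header.
    """
    missing: List[str] = []
    rows = 0
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames or []
        for col in (audio_column, text_column):
            if col not in header:
                raise KeyError(f"column {col!r} not found in header {header}")
        for row in reader:
            rows += 1
            if not os.path.isfile(row[audio_column]):
                missing.append(row[audio_column])
    return rows, missing
```

For example, `check_audio_csv("data/train.csv")` returns the number of rows and the list of audio paths that could not be found.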

WandB for logging

The training code also requires a Weights & Biases account to log the training outputs and demos. Create an account and log in with:

$ wandb login

Alternatively, you can pass your API key through the environment variable WANDB_API_KEY. (You can obtain the API key from https://wandb.ai/authorize after logging in to your account.)

$ export WANDB_API_KEY="12345x6789y..."
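A shell variable only reaches ctm_train.py (launched by ctm_train.sh) if it is exported, for example with `export WANDB_API_KEY=...`. The minimal sketch below checks whether the key is visible in a process environment; it inspects only WANDB_API_KEY and does not look at credentials saved on disk by `wandb login`. The function name `has_wandb_key` is hypothetical.

```python
import os
from typing import Mapping, Optional


def has_wandb_key(env: Optional[Mapping[str, str]] = None) -> bool:
    """True if a non-blank WANDB_API_KEY is present in the given environment."""
    env = os.environ if env is None else env
    return bool(env.get("WANDB_API_KEY", "").strip())


if __name__ == "__main__":
    if has_wandb_key():
        print("WANDB_API_KEY is visible to this process")
    else:
        print("WANDB_API_KEY not set: export it or run `wandb login`")
```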

Citation

@article{saito2024soundctm,
  title={SoundCTM: Uniting Score-based and Consistency Models for Text-to-Sound Generation}, 
  author={Koichi Saito and Dongjun Kim and Takashi Shibuya and Chieh-Hsin Lai and Zhi Zhong and Yuhta Takida and Yuki Mitsufuji},
  journal={arXiv preprint arXiv:2405.18503},
  year={2024}
}

Reference

Part of the code is borrowed from the following repos. We would like to thank the authors of these repos for their contribution.

https://github.com/sony/ctm

https://github.com/declare-lab/tango

https://github.com/haoheliu/AudioLDM

https://github.com/haoheliu/audioldm_eval

Contributors

koichi-saito-sony, ljrabbit

Issues

Issue running inference

Hi, I'm facing the following import error when running ctm_inference.sh:
ImportError: cannot import name 'PositionNet' from 'diffusers.models.embeddings' (/usr/local/lib/python3.8/dist-packages/diffusers/models/embeddings.py)

Here are the steps I followed to try running inference:

  1. docker build -t soundctm .
  2. docker run --gpus all --rm -it -v $(pwd):$(pwd) -w $(pwd) soundctm /bin/bash
  3. chmod +x ctm_inference.sh
  4. ./ctm_inference.sh

Below is the modified contents of ctm_inference.sh with all the correct paths:

python ctm_inference.py \
    --text_encoder_name "google/flan-t5-large" \
    --ctm_unet_model_config "configs/diffusion_model_config.json" \
    --training_args "ckpt/030000/summary.jsonl" \
    --model "ckpt/030000/pytorch_model.bin" \
    --ema_model "ckpt/030000/ema_0.999_030000.pt" \
    --test_file "data/test.csv" \
    --sampler 'determinisitc' --sampling_gamma 0. --omega 3.5 \
    --num_steps 1 --nu 1. --num_samples 1 --batch_size 1 \
    --output_dir "outputs/"

The full log is as shown below:

/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:249: FutureWarning: `AutoencoderTinyBlock` is deprecated and will be removed in version 0.29. Importing `AutoencoderTinyBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AutoencoderTinyBlock`, instead.
  deprecate("AutoencoderTinyBlock", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:254: FutureWarning: `UNetMidBlock2D` is deprecated and will be removed in version 0.29. Importing `UNetMidBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D`, instead.
  deprecate("UNetMidBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:259: FutureWarning: `UNetMidBlock2DCrossAttn` is deprecated and will be removed in version 0.29. Importing `UNetMidBlock2DCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn`, instead.
  deprecate("UNetMidBlock2DCrossAttn", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:264: FutureWarning: `UNetMidBlock2DSimpleCrossAttn` is deprecated and will be removed in version 0.29. Importing `UNetMidBlock2DSimpleCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn`, instead.
  deprecate("UNetMidBlock2DSimpleCrossAttn", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:269: FutureWarning: `AttnDownBlock2D` is deprecated and will be removed in version 0.29. Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownBlock2D`, instead.
  deprecate("AttnDownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:274: FutureWarning: `CrossAttnDownBlock2D` is deprecated and will be removed in version 0.29. Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D`, instead.
  deprecate("CrossAttnDownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:279: FutureWarning: `DownBlock2D` is deprecated and will be removed in version 0.29. Importing `DownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import DownBlock2D`, instead.
  deprecate("DownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:284: FutureWarning: `AttnDownEncoderBlock2D` is deprecated and will be removed in version 0.29. Importing `AttnDownEncoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownEncoderBlock2D`, instead.
  deprecate("AttnDownEncoderBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:289: FutureWarning: `AttnSkipDownBlock2D` is deprecated and will be removed in version 0.29. Importing `AttnSkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipDownBlock2D`, instead.
  deprecate("AttnSkipDownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:294: FutureWarning: `SkipDownBlock2D` is deprecated and will be removed in version 0.29. Importing `SkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipDownBlock2D`, instead.
  deprecate("SkipDownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:299: FutureWarning: `ResnetDownsampleBlock2D` is deprecated and will be removed in version 0.29. Importing `ResnetDownsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D`, instead.
  deprecate("ResnetDownsampleBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:304: FutureWarning: `SimpleCrossAttnDownBlock2D` is deprecated and will be removed in version 0.29. Importing `SimpleCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D`, instead.
  deprecate("SimpleCrossAttnDownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:309: FutureWarning: `KDownBlock2D` is deprecated and will be removed in version 0.29. Importing `KDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KDownBlock2D`, instead.
  deprecate("KDownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:314: FutureWarning: `KCrossAttnDownBlock2D` is deprecated and will be removed in version 0.29. Importing `KCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnDownBlock2D`, instead.
  deprecate("KCrossAttnDownBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:319: FutureWarning: `AttnUpBlock2D` is deprecated and will be removed in version 0.29. Importing `AttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpBlock2D`, instead.
  deprecate("AttnUpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:324: FutureWarning: `CrossAttnUpBlock2D` is deprecated and will be removed in version 0.29. Importing `CrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnUpBlock2D`, instead.
  deprecate("CrossAttnUpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:329: FutureWarning: `UpBlock2D` is deprecated and will be removed in version 0.29. Importing `UpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpBlock2D`, instead.
  deprecate("UpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:334: FutureWarning: `UpDecoderBlock2D` is deprecated and will be removed in version 0.29. Importing `UpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpDecoderBlock2D`, instead.
  deprecate("UpDecoderBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:339: FutureWarning: `AttnUpDecoderBlock2D` is deprecated and will be removed in version 0.29. Importing `AttnUpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpDecoderBlock2D`, instead.
  deprecate("AttnUpDecoderBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:344: FutureWarning: `AttnSkipUpBlock2D` is deprecated and will be removed in version 0.29. Importing `AttnSkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipUpBlock2D`, instead.
  deprecate("AttnSkipUpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:349: FutureWarning: `SkipUpBlock2D` is deprecated and will be removed in version 0.29. Importing `SkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipUpBlock2D`, instead.
  deprecate("SkipUpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:354: FutureWarning: `ResnetUpsampleBlock2D` is deprecated and will be removed in version 0.29. Importing `ResnetUpsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetUpsampleBlock2D`, instead.
  deprecate("ResnetUpsampleBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:359: FutureWarning: `SimpleCrossAttnUpBlock2D` is deprecated and will be removed in version 0.29. Importing `SimpleCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnUpBlock2D`, instead.
  deprecate("SimpleCrossAttnUpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:364: FutureWarning: `KUpBlock2D` is deprecated and will be removed in version 0.29. Importing `KUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KUpBlock2D`, instead.
  deprecate("KUpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:369: FutureWarning: `KCrossAttnUpBlock2D` is deprecated and will be removed in version 0.29. Importing `KCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnUpBlock2D`, instead.
  deprecate("KCrossAttnUpBlock2D", "0.29", deprecation_message)
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py:375: FutureWarning: `KAttentionBlock` is deprecated and will be removed in version 0.29. Importing `KAttentionBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KAttentionBlock`, instead.
  deprecate("KAttentionBlock", "0.29", deprecation_message)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/vedant/repos/soundctm/ctm_inference.py:13 in <module>                                      │
│                                                                                                  │
│    10 import torch                                                                               │
│    11 from accelerate.utils import set_seed                                                      │
│    12 from ctm.inference_sampling import karras_sample                                           │
│ ❱  13 from ctm.script_util import (                                                              │
│    14 │   create_model_and_diffusion,                                                            │
│    15 )                                                                                          │
│    16 from tango_edm.models_edm import build_pretrained_models                                   │
│                                                                                                  │
│ /home/vedant/repos/soundctm/ctm/script_util.py:4 in <module>                                     │
│                                                                                                  │
│     1 import argparse                                                                            │
│     2                                                                                            │
│     3 import numpy as np                                                                         │
│ ❱   4 from tango_edm.models_edm import AudioDiffusionEDM                                         │
│     5                                                                                            │
│     6 from ctm.resample import create_named_schedule_sampler                                     │
│     7                                                                                            │
│                                                                                                  │
│ /home/vedant/repos/soundctm/tango_edm/models_edm.py:11 in <module>                               │
│                                                                                                  │
│     8 from tango_edm.audioldm.variational_autoencoder import AutoencoderKL                       │
│     9 from tango_edm.edm.edm_precond import EDMPrecond, VEPrecond, VPPrecond, iDDPMPrecond       │
│    10 from tango_edm.unet_2d_condition import UNet2DConditionModel as CTMUNet2DConditionModel    │
│ ❱  11 from tango_edm.unet_2d_condition_teacher import UNet2DConditionModel                       │
│    12 from transformers import (                                                                 │
│    13 │   AutoModel,                                                                             │
│    14 │   AutoTokenizer,                                                                         │
│                                                                                                  │
│ /home/vedant/repos/soundctm/tango_edm/unet_2d_condition_teacher.py:20 in <module>                │
│                                                                                                  │
│     17 │   AttnAddedKVProcessor,                                                                 │
│     18 │   AttnProcessor,                                                                        │
│     19 )                                                                                         │
│ ❱   20 from diffusers.models.embeddings import (                                                 │
│     21 │   GaussianFourierProjection,                                                            │
│     22 │   ImageHintTimeEmbedding,                                                               │
│     23 │   ImageProjection,                                                                      │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ImportError: cannot import name 'PositionNet' from 'diffusers.models.embeddings' (/usr/local/lib/python3.8/dist-packages/diffusers/models/embeddings.py)

Appreciate your help, and great work!
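A likely cause, not confirmed in this thread, is that the diffusers release installed in the image is newer than the one the code was written against: the FutureWarnings above already refer to removals in version 0.29, and in newer diffusers releases `PositionNet` appears to have been renamed `GLIGENTextBoundingboxProjection`. The sketch below, run inside the container, reports the installed diffusers version and whether `PositionNet` can still be imported; `check_positionnet` is a hypothetical helper. If it reports `False`, pinning diffusers to an older release is one possible fix.

```python
import importlib
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Union


def check_positionnet() -> Dict[str, Union[Optional[str], bool]]:
    """Report the installed diffusers version and whether PositionNet is importable."""
    try:
        installed: Optional[str] = version("diffusers")
    except PackageNotFoundError:
        return {"diffusers": None, "has_PositionNet": False}
    try:
        embeddings = importlib.import_module("diffusers.models.embeddings")
    except Exception:
        # diffusers is installed but fails to import (e.g. torch is missing).
        return {"diffusers": installed, "has_PositionNet": False}
    return {"diffusers": installed,
            "has_PositionNet": hasattr(embeddings, "PositionNet")}


if __name__ == "__main__":
    print(check_positionnet())
```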
