lucidrains / musiclm-pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch

License: MIT License

Python 100.00%
artificial-intelligence attention-mechanisms deep-learning music-synthesis transformers

musiclm-pytorch's Introduction

MusicLM - Pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch.

They are basically using text-conditioned AudioLM, but surprisingly with the embeddings from a text-audio contrastively learned model named MuLan. MuLan is what will be built out in this repository, with AudioLM modified from the other repository to support the music generation needs here.

Please join us on Discord if you are interested in helping out with the replication with the LAION community.

Video explainer: What's AI by Louis Bouchard

Appreciation

Install

$ pip install musiclm-pytorch

Usage

MuLaN first needs to be trained

import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# get a ton of <sound, text> pairs and train

wavs = torch.randn(2, 1024)
texts = torch.randint(0, 20000, (2, 256))

loss = mulan(wavs, texts)
loss.backward()

# after much training, you can embed sounds and text into a joint embedding space
# for conditioning the audio LM

embeds = mulan.get_audio_latents(wavs)  # during training

embeds = mulan.get_text_latents(texts)  # during inference

To obtain the conditioning embeddings for the three transformers that are a part of AudioLM, you must use the MuLaNEmbedQuantizer like so

from musiclm_pytorch import MuLaNEmbedQuantizer

# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)

quantizer = MuLaNEmbedQuantizer(
    mulan = mulan,                          # pass in trained mulan from above
    conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024
    namespaces = ('semantic', 'coarse', 'fine')
)

# now say you want the conditioning embeddings for semantic transformer

wavs = torch.randn(2, 1024)
conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers

To train (or finetune) the three transformers that are a part of AudioLM, you simply follow the instructions over at audiolm-pytorch for training, but pass in the MuLaNEmbedQuantizer instance to the training classes under the keyword audio_conditioner.

ex. SemanticTransformerTrainer

import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    audio_text_condition = True      # this must be set to True (same for CoarseTransformer and FineTransformer)
).cuda()

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    audio_conditioner = quantizer,   # pass in the MuLaNEmbedQuantizer instance above
    folder ='/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

After much training on all three transformers (semantic, coarse, fine), you will pass your finetuned or trained-from-scratch AudioLM and MuLaN wrapped in MuLaNEmbedQuantizer to the MusicLM

# you need the trained AudioLM (audio_lm) from above
# with the MuLaNEmbedQuantizer (mulan_embed_quantizer)

from musiclm_pytorch import MusicLM

musiclm = MusicLM(
    audio_lm = audio_lm,                 # `AudioLM` from https://github.com/lucidrains/audiolm-pytorch
    mulan_embed_quantizer = quantizer    # the `MuLaNEmbedQuantizer` from above
)

music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4) # sample 4 and pick the top match with mulan

Todo

  • mulan seems to be using decoupled contrastive learning, offer that as an option

  • wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions

  • modify audiolm to accept conditioning embeddings, optionally take care of different dimensions through a separate projection

  • audiolm and mulan go into musiclm, generate, and filter with mulan

  • give dynamic positional bias to self attention in AST

  • implement MusicLM generating multiple samples and selecting top match with MuLaN

  • support variable-length audio with masking in audio transformer

  • add a version of mulan to open clip

  • set all the proper spectrogram hyperparameters

Citations

@inproceedings{Agostinelli2023MusicLMGM,
    title     = {MusicLM: Generating Music From Text},
    author    = {Andrea Agostinelli and Timo I. Denk and Zal{\'a}n Borsos and Jesse Engel and Mauro Verzetti and Antoine Caillon and Qingqing Huang and Aren Jansen and Adam Roberts and Marco Tagliasacchi and Matthew Sharifi and Neil Zeghidour and C. Frank},
    year      = {2023}
}
@article{Huang2022MuLanAJ,
    title   = {MuLan: A Joint Embedding of Music Audio and Natural Language},
    author  = {Qingqing Huang and Aren Jansen and Joonseok Lee and Ravi Ganti and Judith Yue Li and Daniel P. W. Ellis},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.12415}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@article{Liu2022PatchDropoutEV,
    title   = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
    author  = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.07220}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{Shukor2022EfficientVP,
    title   = {Efficient Vision-Language Pretraining with Visual Concepts and Hierarchical Alignment},
    author  = {Mustafa Shukor and Guillaume Couairon and Matthieu Cord},
    booktitle = {British Machine Vision Conference},
    year    = {2022}
}
@inproceedings{Zhai2023SigmoidLF,
    title   = {Sigmoid Loss for Language Image Pre-Training},
    author  = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
    year    = {2023}
}

The only truth is music. - Jack Kerouac

Music is the universal language of mankind. - Henry Wadsworth Longfellow

musiclm-pytorch's People

Contributors

lucidrains


musiclm-pytorch's Issues

It takes forever to generate music samples

music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4) # sample 4 and pick the top match with mulan

I'm using this example line - it's taking me 2.5 hours and then it fails on MPS for Mac. Should this be happening? Trying again on CPU and it seems to be a lot quicker for some reason.

What's the hold up here? Is this lazy execution with training waiting until sample generation is triggered? Or is the sample generation actually taking this long?

Usage about grad_accum_every

I am curious about how "grad_accum_every" is used in https://github.com/lucidrains/musiclm-pytorch/blob/main/musiclm_pytorch/trainer.py#L317

In my previous experience, the model basically gets its gradients (backward) once per step. Why should we split the loss "grad_accum_every" times to compute the gradients within a step?

If I have a GPU constraint (1 T4 GPU), meaning I can only set the batch size to 1 or 2 at each training stage, should I still set "grad_accum_every" to a large number like 16 or 32?

Thank you!
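
For what it's worth, here is a minimal sketch of what gradient accumulation does in general (plain PyTorch, not this repo's trainer): the loss is divided by grad_accum_every and backward is called on several micro-batches before a single optimizer step, so the effective batch size is batch_size * grad_accum_every while peak memory stays at the level of one micro-batch.

import torch
from torch import nn

# generic gradient accumulation sketch, not this repo's trainer
model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters())

grad_accum_every = 16          # with batch_size = 2 this emulates an effective batch of 32

optimizer.zero_grad()
for _ in range(grad_accum_every):
    x, y = torch.randn(2, 16), torch.randn(2, 1)   # one small micro-batch
    loss = nn.functional.mse_loss(model(x), y)
    (loss / grad_accum_every).backward()           # gradients accumulate in .grad across micro-batches
optimizer.step()                                   # one optimizer update per accumulated step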

Error running AudioLM

I am running the audiolm implementation from GitHub and facing an error in the following

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

text = "The sound of a violin playing a sad melody"
generated_wav = audiolm(text=text, batch_size=1)

I have tried changing the dimensions in the transformers, but the issue is still there,

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 1024,
    depth = 6,
    audio_text_condition = True    # this must be set to True (same for SemanticTransformer and FineTransformer)
)

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 1024,
    depth = 6,
    audio_text_condition = True    # this must be set to True (same for SemanticTransformer and FineTransformer)
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    audio_text_condition = True    # this must be set to True (same for CoarseTransformer and FineTransformers)
).cuda()

but I am still getting the following error:

AssertionError: you had specified a conditioning dimension of 1024, yet what was received by the transformer has
dimension of 768

Please help, I need to submit this implementation

Error: name 'partial' is not defined (The new release version)

I'm getting this error on the new release of musiclm-pytorch

Traceback (most recent call last):
  /home/ramsy/projects/music/main.py:47 in <module>
    CLI()
  /home/ramsy/.local/lib/python3.10/site-packages/typer/main.py:214 in __call__
    return get_command(self)(*args, **kwargs)
  /home/ramsy/.local/lib/python3.10/site-packages/click/core.py:829 in __call__
    return self.main(*args, **kwargs)
  /home/ramsy/.local/lib/python3.10/site-packages/click/core.py:782 in main
    rv = self.invoke(ctx)
  /home/ramsy/.local/lib/python3.10/site-packages/click/core.py:1259 in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  /home/ramsy/.local/lib/python3.10/site-packages/click/core.py:1066 in invoke
    return ctx.invoke(self.callback, **ctx.params)
  /home/ramsy/.local/lib/python3.10/site-packages/click/core.py:610 in invoke
    return callback(*args, **kwargs)
  /home/ramsy/.local/lib/python3.10/site-packages/typer/main.py:497 in wrapper
    return callback(**use_params)  # type: ignore
  /home/ramsy/projects/music/main.py:34 in train
    train_model()
  /home/ramsy/projects/music/src/train_model.py:37 in train_model
    mulan = MuLaN(
        audio_transformer = AUDIO_TRANSFORMER,
        text_transformer = TEXT_TRANSFORMER
    )
  <@beartype(musiclm_pytorch.musiclm_pytorch.MuLaN.__init__) at 0x7f7e2d17eef0>:52 in __init__
  /home/ramsy/.local/lib/python3.10/site-packages/musiclm_pytorch/musiclm_pytorch.py:673 in __init__
    klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(Soft
NameError: name 'partial' is not defined

I have looked at the file /home/ramsy/.local/lib/python3.10/site-packages/musiclm_pytorch/musiclm_pytorch.py and I can't seem to find where this partial is defined.
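
For reference, partial normally comes from the standard library functools module, so a hedged guess (not a confirmed fix for this release) is that the module is simply missing the import:

from functools import partial   # standard library; defines partial used on the failing line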

Ex code for CoarseTransformer and FineTransformer

I have been able to successfully train with SemanticTransformerTrainer, but I am getting errors with the latter two.
coarse_transformer = CoarseTransformer(
    codebook_size = wav2vec.codebook_size,
    num_coarse_quantizers = 8,
    num_semantic_tokens = 1000,
    dim = 1024,
    depth = 6,
    audio_text_condition = True    # this must be set to True (same for SemanticTransformer and FineTransformer)
).cuda()

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    wav2vec = wav2vec,
    audio_conditioner = quantizer,    # pass in the MulanEmbedQuantizer instance above
    folder = '/content/music_data',
    soundstream = soundstream,        # where to get this from?
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

Exception when attempting to train

I'm excited to try this out!

I attempted to train, feeding in a MockTextAudioDataset similar to the example on AudioLM's page (that worked with the semantic trainer there), but encountered the following exception: TypeError: 'int' object is not iterable

Full stack trace, in case it helps:

File "train_mulan.py", line 60, in
trainer.train()
File "<@beartype(musiclm_pytorch.trainer.MuLaNTrainer.train) at 0x7ff0e221f160>", line 30, in train
File "/mnt/c/audio-ml-workspace/musiclm/musiclm_pytorch/trainer.py", line 363, in train
logs = self.train_step()
File "/mnt/c/audio-ml-workspace/musiclm/musiclm_pytorch/trainer.py", line 330, in train_step
data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
File "/mnt/c/audio-ml-workspace/musiclm/musiclm_pytorch/trainer.py", line 57, in cycle
for data in dl:
File "/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/accelerate/data_loader.py", line 375, in iter
current_batch = next(dataloader_iter)
File "/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 628, in next
data = self._next_data()
File "/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/qualia/anaconda3/envs/audiolm/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
return self.collate_fn(data)
File "/mnt/c/audio-ml-workspace/musiclm/musiclm_pytorch/trainer.py", line 146, in inner
output = fn(datum)
File "/mnt/c/audio-ml-workspace/musiclm/musiclm_pytorch/trainer.py", line 156, in curtail_to_shortest_collate
min_len = min(*[datum.shape[0] for datum in data])
TypeError: 'int' object is not iterable
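
As an assumption about the root cause (not a confirmed diagnosis), min(*[...]) fails whenever the collated batch reaching that line contains a single element, because the star-unpacking turns it into min of a bare int:

# hedged illustration of the failure mode, assuming a batch of size one reaches the collate function
lengths = [7]
min(*[l for l in lengths])    # unpacks to min(7) -> TypeError: 'int' object is not iterable
min([l for l in lengths])     # taking the min of the list itself works for any batch size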

Inference with MuLaN

@lucidrains Somehow I got MuLaN trained with the MusicCaps dataset. Now I want to check how close the text and wav embeddings are. So while extracting text and wav embeddings using:

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    max_seq_len = 512
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

mulan.to(device)

model = torch.load('results/mulan.45000.pt')

# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in model.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

mulan.load_state_dict(new_state_dict)

wavs = torch.randn(2, 1024)
texts = torch.randint(0, 20000, (2, 256))

wav_emb = mulan.get_audio_latents(wavs) 
text_emb = mulan.get_text_latents(texts)

I get the following error:

Traceback (most recent call last):
  File "test_mulan.py", line 140, in <module>
    audio_emb = mulan.get_audio_latents(wavs)
  File "venv/lib/python3.8/site-packages/musiclm_pytorch/musiclm_pytorch.py", line 502, in get_audio_latents
    audio_embeds = self.audio(wavs)
  File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/venv/lib/python3.8/site-packages/musiclm_pytorch/musiclm_pytorch.py", line 379, in forward
    x = self.transformer(x, rel_pos_bias = rel_pos_bias)
  File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/venv/lib/python3.8/site-packages/musiclm_pytorch/musiclm_pytorch.py", line 217, in forward
    x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x
  File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/venv/lib/python3.8/site-packages/musiclm_pytorch/musiclm_pytorch.py", line 163, in forward
    sim = sim + rel_pos_bias
RuntimeError: The size of tensor a (20) must match the size of tensor b (2560) at non-singleton dimension 3

It happens only for audio inputs. Not for text.

Create endless stream of genre

Hi,

Can I use this to create an endless stream per genre? So a stream of "minimal house", "classical music in Bach style" and "reggae music of the Caribbean"?

inference time

Thanks for your great work!
Could you share with me the inference time of generating 10s audio on CPU/GPU?

Tried to run the example from the README, got an error related to tensor dimensions

I am trying to run the example in Google Colab, but I get a RuntimeError when running the part for obtaining the conditioning embeddings:

from musiclm_pytorch import MuLaNEmbedQuantizer

# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)

quantizer = MuLaNEmbedQuantizer(
    mulan = mulan,                          # pass in trained mulan from above
    conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024
    namespaces = ('semantic', 'coarse', 'fine')
)

# now say you want the conditioning embeddings for semantic transformer

wavs = torch.randn(2, 1024)
conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers

RuntimeError: The size of tensor a (20) must match the size of tensor b (2560) at non-singleton dimension 3

Runtime Error on CPU

Discussed in #55

Originally posted by sauravp June 9, 2023
I am trying to train this on CPU (on a small dataset) to validate some ideas.

import torch
from musiclm_pytorch import MusicLM, MuLaNTrainer
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer, MuLaNEmbedQuantizer

import random
import numpy as np

device = 'cpu'


audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    max_seq_len = 512
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

mulan.to(device)
mulan.eval()

wavs = torch.randn(5, 1024).to(device)
texts = torch.randint(0, 20000, (5, 512)).to(device)
#print(wavs.shape, texts.shape)

from torch.utils.data import Dataset

class TextAudioDataset(Dataset):
    def __init__(self, wavs, texts):
        super().__init__()
        self.wavs = wavs
        self.texts = texts

    def __len__(self):
        if len(self.wavs) != len(self.texts):
            return -1
        else:
            return len(self.wavs)

    def __getitem__(self, idx):
        return self.wavs[idx], self.texts[idx]


trainer = MuLaNTrainer(
    mulan = mulan,
    dataset = TextAudioDataset(wavs, texts),
    batch_size = 2
)

trainer.to(device)

trainer.train()

I am getting the following error:

RuntimeError: stft input and window must be on the same device but got self on mps:0 and window on cpu

Is there a way to run the entire thing on CPU?

Additional Work

Could you please document any additional work which is needed for this project? I'd like to contribute.

time of training

Thanks for your great work!
I'm interested in audio music generation. Given the big dataset and complex architecture, I wonder how long the model would take to train in general.

checkpoint files?

I don't even know if it makes sense in this context, but would it be possible to release a checkpoint file for a trained model so we can run inference without having to train ourselves?

Getting started

I'm really interested in generating music from text. However, I have poor knowledge of the pytorch framework and ML in general.

I appreciate the README a lot. However, it's quite implicit, especially for a newbie like me. Is there a way I can get started?

How do I get the ./hubert/hubert_base_ls960.pt and ./hubert/hubert_base_ls960_L9_km500.bin files? Where do I download audio files?

Training MuLan

@lucidrains The dataset for training MuLan in their original paper seems to be private, so we need to look at other options, like the Free Music Archive (FMA) dataset, where the text part of a sample is a list of strings, like:

['low quality', 'sustained strings melody', 'soft female vocal', 'mellow piano melody', 'sad', 'soulful', 'ballad']

My question is: Which string should we feed to our network along with the audio? One randomly selected string? Or all? Or make pairs of all with the audio to make even more samples?
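
For illustration only, a hedged sketch of two of the options above (randomly sampling one tag per step, or joining all tags into a single caption); nothing here is prescribed by the repo:

import random

tags = ['low quality', 'sustained strings melody', 'soft female vocal',
        'mellow piano melody', 'sad', 'soulful', 'ballad']

caption = random.choice(tags)     # option 1: one randomly selected tag per training step
caption_all = ', '.join(tags)     # option 2: concatenate all tags into a single caption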

Controlling the length of output when calling MusicLM model

Using the default settings, it seems that MusicLM will always output a tensor of length 163840. This is a bit of a strange number, as it's not divisible by the standard sample rate of 44100 that it would presumably be trained on.

I've found that it's possible to pass a max_length argument when calling MusicLM, which gets passed to AudioLM. But passing this argument only controls how many semantic tokens are generated - the coarse, fine and output tensor remain the same size.

For now I've hacked a solution together by additionally passing a max_length to the self.coarse.generate() call in audiolm_pytorch:1628, but I'm wondering if this is the correct way to do it.

What's the best way to generate outputs of different lengths with this model?

Support our open source music pretrained Transformer

Hi, we are researchers from the MAP (music audio pre-train) project. We pre-train transformer LMs on large-scale music audio datasets.
See below. Our model, MERT, uses a similar method to HuBERT and has verified its performance on downstream music information retrieval tasks. It has been released on Hugging Face and can be used interchangeably with the HuBERT loading code: model = HubertModel.from_pretrained("m-a-p/MERT-v0")
We are currently working on training a better base model and scaling up to a large model with more music+speech data.
Using our weights as an initialization will be a better start than using speech HuBERT. Better checkpoints will be released soon.

https://huggingface.co/m-a-p/MERT-v0

nan loss in MuLaN training

@lucidrains
While training MuLaN on a dataset of around 5.2k samples, the loss goes to nan after some 15-16k steps.
My batch size is 4, and the text part of the data samples are tokenized using:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
text_in_numbers = tokenizer.encode(text)

Does it have something to do with a zero division, or a square root of 0, in the loss function?
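
In case it helps, a generic way to localize where the first non-finite value appears (plain PyTorch anomaly detection, not specific to this repo; mulan, wavs and texts are assumed to be the training objects from the snippets above):

import torch

torch.autograd.set_detect_anomaly(True)    # slow; enable only while debugging

loss = mulan(wavs, texts)                  # assumed objects from the earlier snippets
assert torch.isfinite(loss), 'loss became non-finite'
loss.backward()                            # anomaly mode reports the op that produced the first nan/inf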

How to include audioset

Hello,

I am fairly new to the whole topic of huge ML models and I'd really like to try this out.
However, I don't quite get a grasp on how to implement the dataset to train the model.
In another issue, someone mentioned Audioset (https://research.google.com/audioset/download.html) which
sounds interesting but apparently "only" offers large csv-files.
How do I use them in this project, where a path to the dataset is required? Or is there a tutorial where I can look this up?

Thank you so much in advance!

Install and use musiclm

Hello,
I'd like to input a prompt and get music from your AI, but I'm not sure how I can install it.

Happy to help

I've been working on some basic music generation lately. I'm very interested in this implementation and would be happy to contribute in any way. Maybe that's as simple as donating to sustain your work, or it could be contributing code. Just want to put it out there that I support what you're doing and I'm very interested to see an implementation of this in PyTorch.

Nameerror in musiclm_pytorch.py

Hi, I'm implementing the code in Google Colab, and I found a strange error.

import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

The error occurs like this:

Traceback (most recent call last):
  in <cell line: 1>:1
  in __init__:52
  /usr/local/lib/python3.9/dist-packages/musiclm_pytorch/musiclm_pytorch.py:673 in __init__
    klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(Soft
NameError: name 'partial' is not defined

I'm getting a memory error that seems unrealistic (small dataset) so I think I've messed up or there's a bug

Can you help with this when you have a moment please? I'd much appreciate it.

this is the error:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.3.2\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:\ML\playingWithMusicLM\main.py", line 114, in <module>
    loss = mulan(wavsTensor, selectedTextsTensor)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<@beartype(musiclm_pytorch.musiclm_pytorch.MuLaN.forward) at 0x24d7011c670>", line 47, in forward
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 547, in forward
    audio_latents = self.get_audio_latents(wavs)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 523, in get_audio_latents
    audio_embeds = self.audio(wavs)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\user\.conda\envs\playingWithMusicLM\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 396, in forward
    rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())
RuntimeError: [enforce fail at ..\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 2250000000 bytes.

I'm just running 10 song/text pairs as tensor params in order to "train mulan"

The code to do so is as follows:

 # get a ton of <sound, text> pairs and train
ids = []
texts = []
with open(musicDescriptiveMetadataFilename, newline='\n') as csvfile:
    rows = csv.reader(csvfile, delimiter=',')
    for rowNumber, row in enumerate(rows):
        if rowNumber > 0:
            id = row[0]
            text = row[5]
            ids.append(id)
            texts.append(text)

wavs = []
selectedTexts = []

audioFileNames = os.listdir(".//ytRips")
for n, id in enumerate(ids):
    if n < 10:
        for audioFileName in audioFileNames:
            if audioFileName.__contains__(id):
                a = read(".\\ytRips\\" + audioFileName)
                a = np.array(a[1], dtype=np.float32)
                try:
                    channels=a.shape[1]
                except:
                    channels=1
                    continue

                samples=a.shape[0]

                if channels==2:
                    a = np.resize(a, (samples,1))

                if samples == 480000:
                    wavs.append(a)
                    selectedTexts.append(numpy.asarray(stringToListOfInts(texts[n]),dtype=np.compat.long))

#resize texts to same size
resizedSelectedTexts=[]
for selectedText in selectedTexts:
    size=selectedText.shape[0]
    if size > 450:
        resizedSelectedTexts.append(numpy.resize(selectedText,(450,1)))
    else:
        tmp=selectedText
        for x in range(10):
            tmp=np.concatenate((tmp, selectedText), axis=0)
        if tmp is not None:
            resizedSelectedTexts.append(numpy.resize(np.stack(tmp, axis=0), (450,1)))

wavsTensor = torch.squeeze(torch.tensor(np.stack(wavs, axis=0),dtype=torch.float32))
selectedTextsTensor= torch.squeeze(torch.tensor(np.stack(resizedSelectedTexts, axis=0),dtype=torch.long))

loss = mulan(wavsTensor, selectedTextsTensor)
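
As a hedged guess at the cause (not a confirmed diagnosis): 480000-sample clips make the attention and relative-position-bias tensors inside the audio transformer very large, so cropping each clip before calling mulan keeps memory bounded. max_samples below is just an illustrative cap, mirroring the data_max_length used in the trainer example above.

max_samples = 320 * 32                          # illustrative cap, mirrors data_max_length above
wavsTensor = wavsTensor[:, :max_samples]        # crop every clip to the same bounded length
loss = mulan(wavsTensor, selectedTextsTensor)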

Training data

Hi there!
I'm trying to train the model using the MusicCaps dataset.

However, on the readme, according to wavs = torch.randn(2, 1024) it looks like the audio tensors are 2x1024 (which makes me think it's requiring stereo audio).
The MusicCaps audio is actually mono.

I'm not sure if I'm interpreting this correctly. Could you give me a hint here?
Thanks!
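
For what it's worth, a hedged reading of that README line (an assumption, not a confirmed answer): the tensor is (batch, num_samples) of mono audio, so the 2 is a batch of two clips rather than a stereo channel dimension.

import torch

wavs = torch.randn(2, 1024)              # assumed meaning: two mono clips of 1024 samples each

# batching two mono MusicCaps clips of equal length would then look like
clip_a = torch.randn(160000)             # (num_samples,)
clip_b = torch.randn(160000)
batch = torch.stack([clip_a, clip_b])    # (2, 160000)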

Few questions

First of all, it is a very interesting project.
Thanks for your work!

So, I'm trying to implement this project step by step on Colab(https://colab.research.google.com/drive/1fkXdwUBw9tDxofj5-us0vuOenuqC7rfZ?usp=sharing).

But there is something bothering me: the code below finished very quickly, in about 10 seconds.

import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# get a ton of <sound, text> pairs and train

wavs = torch.randn(2, 1024)
texts = torch.randint(0, 20000, (2, 256))

loss = mulan(wavs, texts)
loss.backward()

With the message below:

spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer

Is this normal behaviour?

And to create musiclm it requires audiolm; to create audiolm, do I have to create soundstream, coarse_transformer and fine_transformer according to here (https://github.com/lucidrains/audiolm-pytorch)?
Or is there another way to achieve it?

audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

musiclm = MusicLM(
    audio_lm = embeds_audio,
    mulan_embed_quantizer = quantizer
)

music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.Tensor

ast pretrained model use

@lucidrains Hi, I want to use the pretrained AST (https://github.com/YuanGongND/ast) model as the audio transformer here.
In musiclm-pytorch, the input audio wave is 2-dimensional (e.g. (2, 1024)); however, the pretrained AST needs a 3-dimensional input (e.g. (batch_size, time, frequency)).

I saw an example someone made for applying the audiocap dataset to MuLaN training, but it didn't work because of the difference in input dimensions.
So, do you have an idea how to apply an existing dataset (e.g. audiocap) to this MuLaN model?
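
As a hedged sketch (the hyperparameters are placeholders, not the values AST was trained with), torchaudio can turn the (batch, num_samples) waveforms used here into the (batch, time, frequency) features an AST-style model expects:

import torch
import torchaudio

wavs = torch.randn(2, 16000)                  # (batch, num_samples), mono

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate = 16000,                      # placeholder values, adjust to the pretrained checkpoint
    n_fft = 400,
    n_mels = 128
)(wavs)                                       # (batch, n_mels, time)

ast_input = mel.transpose(1, 2)               # (batch, time, n_mels), i.e. (batch, time, frequency)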
