Towards End-to-End Generative Modeling of Long Videos with Memory-Efficient Bidirectional Transformers (CVPR 2023)

This repository is an official implementation of the paper:
Towards End-to-End Generative Modeling of Long Videos with Memory-Efficient Bidirectional Transformers (CVPR 2023)
Jaehoon Yoo, Semin Kim, Doyup Lee, Chiheon Kim, Seunghoon Hong
Project Page: https://sites.google.com/view/mebt-cvpr2023/home | Paper: https://arxiv.org/abs/2303.11251

Setup

We installed the packages specified in requirements.txt on top of the following Docker image:

docker pull pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
docker run -it --shm-size=24G pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel /bin/bash
git clone https://github.com/Ugness/MeBT
cd MeBT
pip install -r requirements.txt
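
To verify the environment, a quick sanity check (assumes the steps above completed without errors):

# Should print the installed torch version and True if a GPU is visible.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"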

Datasets

Download

Data preprocessing & setup

  1. Extract all frames from each video. Filenames should follow the pattern [VIDEO_ID]_[FRAME_NUM].[png, jpg, ...] (see the sketch after this list).
  2. Create train.txt and test.txt containing the paths of all extracted frames. For example,
find $(pwd)/dataset/train -name "*.png" >> 'train.txt'
find $(pwd)/dataset/test -name "*.png" >> 'test.txt'
  3. Place the txt files at [DATA_PATH]/train.txt and [DATA_PATH]/test.txt.
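
For step 1, a minimal frame-extraction sketch using ffmpeg (assuming ffmpeg is installed and the raw videos are .mp4 files under a hypothetical raw_videos/ directory; adjust paths, extensions, and padding to your dataset):

# Extract frames from every video into dataset/train/,
# named [VIDEO_ID]_[FRAME_NUM].png as required above.
mkdir -p dataset/train
for video in raw_videos/*.mp4; do
    vid=$(basename "$video" .mp4)
    # %05d zero-pads the frame number, e.g. video001_00001.png
    ffmpeg -i "$video" "dataset/train/${vid}_%05d.png"
done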

Checkpoints

  1. VQGAN: Checkpoints for the 3D VQGAN can be found here.
  2. MeBT: Checkpoints for MeBT can be found here.

Configuration Files

You can control the experiments with configuration files. The default configuration files can be found in the configs folder.

Here is an example of the config file.

model:
    target: mebt.transformer.Net2NetTransformer
    params:
        unconditional: True
        vocab_size: 16384 # Must match the vocab_size of the 3D VQGAN.
        first_stage_vocab_size: 16384
        block_size: 1024 # total number of input tokens (output of the 3D VQGAN)
        n_layer: 24 # number of layers for MeBT
        n_head: 16 # number of attention heads
        n_embd: 1024 # hidden dimension
        n_unmasked: 0
        embd_pdrop: 0.1 # Dropout ratio
        resid_pdrop: 0.1 # Dropout ratio
        attn_pdrop: 0.1 # Dropout ratio
        sample_every_n_latent_frames: 0
        first_stage_key: video # ignore
        cond_stage_key: label # ignore
        vtokens: False # ignore
        vtokens_pos: False # ignore
        vis_epoch: 100
        sos_emb: 256 # Number of latent tokens.
        avg_loss: True
        mode: # You may stack different types of layers. The total number of layers must match n_layer.
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
    mask:
        target: mebt.mask_sampler.MaskGen
        params:
            iid: False
            schedule: linear
            max_token: 1024 # total number of input tokens (output of the 3D VQGAN)
            method: 'mlm'
            shape: [4, 16, 16] # shape of the output of the 3D VQGAN (T, H, W)
            t_range: [0.0, 1.0]
            budget: 1024 # total number of input tokens (output of the 3D VQGAN)

    vqvae:
        params:
            ckpt_path: 'ckpts/vqgan_sky_128_488_epoch=12-step=29999-train.ckpt' # Path to the 3D VQGAN checkpoint.
            ignore_keys: ['loss']

data:
    data_path: 'datasets/vqgan_data/stl_128' # [DATA_PATH]
    sequence_length: 16 # Length of the training video (in frames)
    resolution: 128 # Resolution of the training video (in pixels)
    batch_size: 6 # Batch_size per GPU
    num_workers: 8
    image_channels: 3
    smap_cond: 0
    smap_only: False
    text_cond: False
    vtokens: False
    vtokens_pos: False
    spatial_length: 0
    sample_every_n_frames: 1
    image_folder: True
    stft_data: False

exp:
    exact_lr: 1.08e-05 # learning rate
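
The token budget and the layer stack in a config must stay consistent: block_size, max_token, and budget should all equal the product of the mask shape (4 * 16 * 16 = 1024 here), and the number of entries under mode should equal n_layer (24 here). A quick sanity check (a minimal sketch, assuming PyYAML is installed; the config path is hypothetical):

python -c "
import yaml
cfg = yaml.safe_load(open('configs/my_config.yaml'))  # hypothetical path
p = cfg['model']['params']
T, H, W = cfg['model']['mask']['params']['shape']
assert p['block_size'] == T * H * W    # 4 * 16 * 16 = 1024
assert p['n_layer'] == len(p['mode'])  # 24 entries listed under mode
print('config OK')
"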

Training

The scripts for training can be found in the scripts folder. You can execute a script as follows:

bash scripts/train_config_log_gpus.sh [CONFIG_FILE] [LOG_DIR] [GPU_IDs]
  • [GPU_IDs]:
    • 0, : use GPU_ID 0 only.
    • 0,1,2,3,4,5,6,7 : use 8 GPUs from 0 to 7.

Inference

The scripts for inference can be found in the scripts folder. You can execute a script as follows:

bash scripts/valid_dnr_config_ckpt_exp_[stl, taichi, ucf]_[16f, 128f].sh [CONFIG_FILE] [CKPT_PATH] [SAVE_DIR]
  • Set [DATA_PATH] in the script file to your dataset path to measure FVD and KVD.
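
For example (the config, checkpoint, and output paths here are hypothetical):

# 16-frame generation with the stl script; results are saved under results/stl_16f
bash scripts/valid_dnr_config_ckpt_exp_stl_16f.sh configs/stl.yaml ckpts/mebt_stl.ckpt results/stl_16f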

Acknowledgements

  • Our code is based on VQGAN and TATS.
  • The development of this open-sourced code was supported in part by the National Research Foundation of Korea (NRF) (No. 2021R1A4A3032834).

Citation

@article{yoo2023mebt,
         title={Towards End-to-End Generative Modeling of Long Videos with Memory-Efficient Bidirectional Transformers},
         author={Yoo, Jaehoon and Kim, Semin and Lee, Doyup and Kim, Chiheon and Hong, Seunghoon},
         journal={arXiv preprint arXiv:2303.11251},
         year={2023}
}


