ugness / mebt

Official implementation of the paper: Towards End-to-End Generative Modeling of Long Videos with Memory-Efficient Bidirectional Transformers (CVPR 2023)

Home Page: https://sites.google.com/view/mebt-cvpr2023/home

Python 93.70% Shell 6.30%
transformer-architecture video-generation bidirectional-transformer


Towards End-to-End Generative Modeling of Long Videos with Memory-Efficient Bidirectional Transformers (CVPR 2023)

This repository is an official implementation of the paper:
Towards End-to-End Generative Modeling of Long Videos with Memory-Efficient Bidirectional Transformers (CVPR 2023)
Jaehoon Yoo, Semin Kim, Doyup Lee, Chiheon Kim, Seunghoon Hong
Project Page | Paper

Setup

We installed the packages listed in requirements.txt on top of the following Docker image:

docker pull pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
docker run -it --shm-size=24G pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel /bin/bash
git clone https://github.com/Ugness/MeBT
cd MeBT
pip install -r requirements.txt

Datasets

Download

Data preprocessing & setup

  1. Extract every frame of each video. Each frame file should be named [VIDEO_ID]_[FRAME_NUM].[png, jpg, ...]
  2. Create train.txt and test.txt listing the absolute paths of all frames. For example:
find $(pwd)/dataset/train -name "*.png" >> 'train.txt'
find $(pwd)/dataset/test -name "*.png" >> 'test.txt'
  3. Place the txt files at [DATA_PATH]/train.txt and [DATA_PATH]/test.txt.
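The two `find` commands above can also be reproduced in Python, which makes it easy to check that every frame follows the [VIDEO_ID]_[FRAME_NUM] naming rule before training starts. This is a minimal sketch, not part of the MeBT codebase: the helper names `parse_frame_name` and `write_frame_list` are hypothetical, and it assumes the video ID and frame number are separated by the last underscore and that frame numbers are plain integers.

```python
from pathlib import Path

# Frame extensions accepted by this sketch (assumption; the README allows png, jpg, ...).
FRAME_EXTS = {".png", ".jpg", ".jpeg"}


def parse_frame_name(path):
    """Split '<VIDEO_ID>_<FRAME_NUM>.<ext>' into (video_id, frame_num).

    Splits at the LAST underscore, so video IDs may contain underscores
    (an assumption about the naming scheme).
    """
    stem = Path(path).stem
    video_id, sep, frame = stem.rpartition("_")
    if not sep or not video_id or not frame.isdigit():
        raise ValueError(f"unexpected frame filename: {path}")
    return video_id, int(frame)


def write_frame_list(frame_root, out_txt):
    """Write the absolute path of every frame under frame_root, one per line.

    Returns the number of frames written. Every filename is validated with
    parse_frame_name first, so a badly named file raises ValueError before
    anything is written.
    """
    root = Path(frame_root).resolve()
    frames = sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in FRAME_EXTS
    )
    for p in frames:
        parse_frame_name(p)  # raises ValueError on a badly named file
    with open(out_txt, "w") as f:
        for p in frames:
            f.write(f"{p}\n")
    return len(frames)
```

For example, `write_frame_list("dataset/train", "train.txt")` mirrors the first `find` command. Unlike `>>`, it overwrites the output file, so rerunning it does not duplicate entries.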

Checkpoints

  1. VQGAN: Checkpoints for the 3D VQGAN can be found here.
  2. MeBT: Checkpoints for MeBT can be found here.

Configuration Files

You can control experiments through configuration files. The default configuration files are in the configs folder.

Here is an example configuration file:

model:
    target: mebt.transformer.Net2NetTransformer
    params:
        unconditional: True
        vocab_size: 16384 # Must match the vocab_size of the 3D VQGAN.
        first_stage_vocab_size: 16384
        block_size: 1024 # Total number of input tokens (output of the 3D VQGAN)
        n_layer: 24 # number of layers for MeBT
        n_head: 16 # number of attention heads
        n_embd: 1024 # hidden dimension
        n_unmasked: 0
        embd_pdrop: 0.1 # Dropout ratio
        resid_pdrop: 0.1 # Dropout ratio
        attn_pdrop: 0.1 # Dropout ratio
        sample_every_n_latent_frames: 0
        first_stage_key: video # ignore
        cond_stage_key: label # ignore
        vtokens: False # ignore
        vtokens_pos: False # ignore
        vis_epoch: 100
        sos_emb: 256 # Number of latent tokens.
        avg_loss: True
        mode: # You may stack different types of layers. The number of entries must equal n_layer.
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_self
            - latent_enc
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
            - lt2l
            - latent_dec
    mask:
        target: mebt.mask_sampler.MaskGen
        params:
            iid: False
            schedule: linear
            max_token: 1024 # Total number of input tokens (output of the 3D VQGAN)
            method: 'mlm'
            shape: [4, 16, 16] # Shape of the 3D VQGAN output (T, H, W)
            t_range: [0.0, 1.0]
            budget: 1024 # Total number of input tokens (output of the 3D VQGAN)

    vqvae:
        params:
            ckpt_path: 'ckpts/vqgan_sky_128_488_epoch=12-step=29999-train.ckpt' # Path to the 3d VQGAN checkpoint.
            ignore_keys: ['loss']

data:
    data_path: 'datasets/vqgan_data/stl_128' # [DATA_PATH]
    sequence_length: 16 # Length of the training video (in frames)
    resolution: 128 # Resolution of the training video (in pixels)
    batch_size: 6 # Batch size per GPU
    num_workers: 8
    image_channels: 3
    smap_cond: 0
    smap_only: False
    text_cond: False
    vtokens: False
    vtokens_pos: False
    spatial_length: 0
    sample_every_n_frames: 1
    image_folder: True
    stft_data: False

exp:
    exact_lr: 1.08e-05 # learning rate
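Several values in the example configuration must agree with each other: the length of `mode` must equal `n_layer` (24 here), and, according to the comments above, `block_size`, `max_token` and `budget` all equal the number of 3D VQGAN output tokens, i.e. the product of `shape` (4 × 16 × 16 = 1024). The sketch below checks these constraints on a config that has already been loaded into a dict (for example with PyYAML's `yaml.safe_load`) and rebuilds the 24-entry `mode` list from the example. `example_mode_list` and `check_config` are hypothetical helpers; the repository may validate its configs differently or not at all.

```python
from math import prod


def example_mode_list():
    """Rebuild the 24-entry `mode` list from the example config.

    First half: 6 x (latent_enc, latent_self) followed by one latent_enc.
    Second half: 5 x (latent_dec, lt2l) followed by one latent_dec.
    """
    enc = ["latent_enc", "latent_self"] * 6 + ["latent_enc"]
    dec = ["latent_dec", "lt2l"] * 5 + ["latent_dec"]
    return enc + dec


def check_config(cfg):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    params = cfg["model"]["params"]
    mask = cfg["model"]["mask"]["params"]

    # Each entry in `mode` is one layer, so the list length must equal n_layer.
    if len(params["mode"]) != params["n_layer"]:
        problems.append(
            f"len(mode)={len(params['mode'])} but n_layer={params['n_layer']}"
        )

    # The 3D VQGAN output has shape (T, H, W); its size is the input token count.
    n_tokens = prod(mask["shape"])
    for where, value in [
        ("model.params.block_size", params["block_size"]),
        ("model.mask.params.max_token", mask["max_token"]),
        ("model.mask.params.budget", mask["budget"]),
    ]:
        if value != n_tokens:
            problems.append(f"{where}={value} but prod(shape)={n_tokens}")
    return problems
```

With PyYAML installed, `check_config(yaml.safe_load(open(path)))` returns an empty list for the example above; a non-empty list names each mismatched field.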

Training

The training scripts are in the scripts folder. You can execute a script as follows:

bash scripts/train_config_log_gpus.sh [CONFIG_FILE] [LOG_DIR] [GPU_IDs]
  • [GPU_IDs]:
    • 0, : use GPU_ID 0 only.
    • 0,1,2,3,4,5,6,7 : use 8 GPUs from 0 to 7.
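[GPU_IDs] is a comma-separated list of device indices, while `batch_size` in the config is per GPU. Assuming standard data-parallel training (each GPU processes its own batch, no gradient accumulation), the effective batch size is `batch_size` times the number of GPUs, e.g. 6 × 8 = 48 for the example config on GPUs 0 to 7. The helpers below are a hypothetical sketch of that arithmetic, not code from the repository; the training script itself handles the actual device setup.

```python
def parse_gpu_ids(gpu_ids):
    """Parse a string like '0,' or '0,1,2,3' into a list of device indices.

    Empty items (such as the one produced by the trailing comma in '0,')
    are ignored.
    """
    ids = [int(tok) for tok in gpu_ids.split(",") if tok.strip()]
    if len(set(ids)) != len(ids):
        raise ValueError(f"duplicate GPU id in {gpu_ids!r}")
    return ids


def effective_batch_size(per_gpu_batch, gpu_ids):
    """Samples per optimizer step, assuming one batch per GPU (data
    parallelism) and no gradient accumulation."""
    return per_gpu_batch * len(parse_gpu_ids(gpu_ids))
```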

Inference

The inference scripts are in the scripts folder. You can execute a script as follows:

bash scripts/valid_dnr_config_ckpt_exp_[stl, taichi, ucf]_[16f, 128f].sh [CONFIG_FILE] [CKPT_PATH] [SAVE_DIR]
  • To measure FVD and KVD, set [DATA_PATH] in the script file to the path of your dataset.
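The inference script name encodes the dataset (stl, taichi or ucf) and the video length (16f or 128f). The following hypothetical helper builds the command line from those two choices plus the three positional arguments, quoting each argument for the shell. It is only a sketch: it assumes a script exists for every dataset/length combination, so check the scripts folder before relying on it.

```python
import shlex

DATASETS = ("stl", "taichi", "ucf")
LENGTHS = ("16f", "128f")


def inference_command(dataset, length, config_file, ckpt_path, save_dir):
    """Build the shell command for one of the inference scripts.

    Script names follow the pattern
    scripts/valid_dnr_config_ckpt_exp_[stl, taichi, ucf]_[16f, 128f].sh
    """
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}, got {dataset!r}")
    if length not in LENGTHS:
        raise ValueError(f"length must be one of {LENGTHS}, got {length!r}")
    script = f"scripts/valid_dnr_config_ckpt_exp_{dataset}_{length}.sh"
    args = ["bash", script, config_file, ckpt_path, save_dir]
    # shlex.quote leaves plain paths untouched and quotes ones with spaces.
    return " ".join(shlex.quote(a) for a in args)
```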

Acknowledgements

  • Our code is based on VQGAN and TATS.
  • The development of this open-sourced code was supported in part by the National Research Foundation of Korea (NRF) (No. 2021R1A4A3032834).

Citation

@article{yoo2023mebt,
         title={Towards End-to-End Generative Modeling of Long Videos with Memory-Efficient Bidirectional Transformers},
         author={Jaehoon Yoo and Semin Kim and Doyup Lee and Chiheon Kim and Seunghoon Hong},
         journal={arXiv preprint arXiv:2303.11251},
         year={2023}
}
