This project is forked from vmurahari3/visdial-bert

Implementation for "Large-scale Pretraining for Visual Dialog" https://arxiv.org/abs/1912.02379

License: BSD 3-Clause "New" or "Revised" License

VisDial-BERT

PyTorch implementation for the paper:

Large-scale Pretraining for Visual Dialog: A Simple State-of-the-Art Baseline
Vishvak Murahari, Dhruv Batra, Devi Parikh, Abhishek Das

Prior work in visual dialog has focused on training deep neural models on the VisDial dataset in isolation, which has led to great progress, but is limiting and wasteful. In this work, following recent trends in representation learning for language, we introduce an approach to leverage pretraining on related large-scale vision-language datasets before transferring to visual dialog. Specifically, we adapt the recently proposed ViLBERT model for multi-turn visually-grounded conversation sequences. Our model is pretrained on the Conceptual Captions and Visual Question Answering datasets, and finetuned on VisDial with a VisDial-specific input representation and the masked language modeling and next sentence prediction objectives (as in BERT). Our best single model achieves state-of-the-art on Visual Dialog, outperforming prior published work (including model ensembles) by more than 1% absolute on NDCG and MRR.

This repository contains code for reproducing results with and without finetuning on dense annotations. All results are on v1.0 of the Visual Dialog dataset. We provide pretrained model weights and associated configs to run inference or train these models from scratch.

If you find this work useful in your research, please cite:

@article{visdial_bert,
  title={Large-scale Pretraining for Visual Dialog: A Simple State-of-the-Art Baseline},
  author={Vishvak Murahari and Dhruv Batra and Devi Parikh and Abhishek Das},
  journal={arXiv preprint arXiv:1912.02379},
  year={2019},
}

Setup and Dependencies

Our code is implemented in PyTorch (v1.0). To set up, do the following:

  1. Install Python 3.6
  2. Get the source:
git clone https://github.com/vmurahari3/visdial-bert.git visdial-bert
  3. Install requirements into the visdial-bert virtual environment, using Anaconda:
conda env create -f env.yml

Usage

Make both scripts in scripts/ executable:

chmod +x scripts/download_preprocessed.sh
chmod +x scripts/download_checkpoints.sh

Download preprocessed data

Download preprocessed dataset and extracted features:

sh scripts/download_preprocessed.sh

To get these files from scratch:

python preprocessing/pre_process_visdial.py 

However, we recommend downloading these files directly.

Pre-trained checkpoints

Download pre-trained checkpoints:

sh scripts/download_checkpoints.sh

Training

After running the above scripts, all the pre-processed data is downloaded to data/visdial, and the major pre-trained model checkpoints used in the paper are downloaded to checkpoints-release.

Here we list the training arguments to train the important variants in the paper.

To train the base model (no finetuning on dense annotations):

python train.py -batch_size 80 -batch_multiply 1 -lr 2e-5 -image_lr 2e-5 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/vqa_pretrained_weights

To finetune the base model with dense annotations:

python dense_annotation_finetuning.py -batch_size 80 -batch_multiply 10 -lr 1e-4 -image_lr 1e-4 -nsp_loss_coeff 0 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/basemodel

To finetune the base model with dense annotations and the next sentence prediction (NSP) loss:

python dense_annotation_finetuning.py -batch_size 80 -batch_multiply 10 -lr 1e-4 -image_lr 1e-4 -nsp_loss_coeff 1 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/basemodel

NOTE: Dense annotation finetuning is currently supported only for 8-GPU training, primarily because of memory constraints. Computing the cross entropy loss over the 100 answer options at a dialog round requires all 100 dialog sequences to be in memory. We can fit only 80 sequences on 8 GPUs with ~12 GB RAM, so we select only 80 of the 100 options. Performance gets worse with fewer GPUs because the number of answer options has to be cut down further.
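
To illustrate the option subsampling described in the note, here is a minimal sketch in plain Python. It is not the code in dense_annotation_finetuning.py: the helper names select_options and cross_entropy are hypothetical, the random scores stand in for model outputs, and two simplifications are assumed. The ground-truth option is always kept in the subset, and the loss uses a single target index rather than dense relevance scores.

```python
import math
import random


def select_options(num_options, keep, gt_index, rng):
    """Pick `keep` distinct option indices, always retaining gt_index (assumption)."""
    others = [i for i in range(num_options) if i != gt_index]
    chosen = rng.sample(others, keep - 1) + [gt_index]
    rng.shuffle(chosen)
    return chosen


def cross_entropy(scores, target_pos):
    """Softmax cross entropy for one example with a single target position."""
    m = max(scores)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[target_pos]


rng = random.Random(0)
scores = [rng.gauss(0.0, 1.0) for _ in range(100)]  # stand-in for model scores
gt_index = 7  # hypothetical ground-truth option

subset = select_options(100, 80, gt_index, rng)
subset_scores = [scores[i] for i in subset]
loss = cross_entropy(subset_scores, subset.index(gt_index))
print(f"loss over {len(subset)} options: {loss:.4f}")
```

Only the 80 selected sequences would need a forward pass in this scheme, which is where the memory saving comes from.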

Evaluation

The following command generates a prediction file that can be submitted to the test server to get results on the test split.

python evaluate.py -n_gpus 8 -start_path <path to model> -save_name <name of model>
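
As a rough illustration of what a prediction file of this kind contains, the sketch below converts per-option scores into ranks and serializes one entry. The field names image_id, round_id and ranks are an assumption based on the public Visual Dialog challenge submission format, not taken from evaluate.py. The helper names, the IDs and the three-option score list are hypothetical; a real entry would carry 100 ranks.

```python
import json


def scores_to_ranks(scores):
    """Return the 1-based rank of every option (1 = highest score)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranks = [0] * len(scores)
    for position, option_index in enumerate(order, start=1):
        ranks[option_index] = position
    return ranks


def make_entry(image_id, round_id, scores):
    """Build one prediction entry; field names are assumed, see the note above."""
    return {
        "image_id": image_id,
        "round_id": round_id,
        "ranks": scores_to_ranks(scores),
    }


# Hypothetical image and round IDs with a toy three-option score list.
entry = make_entry(1, 3, [0.1, 0.7, 0.2])
payload = json.dumps([entry])
print(payload)
```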

The metrics for the pretrained checkpoints should match the numbers reported in the paper. For convenience, they are also listed below. These results are on v1.0 test-std.

Checkpoint                 Mean Rank   MRR     R@1     R@5     R@10    NDCG
basemodel                  3.32        67.50   53.85   84.68   93.25   63.87
basemodel + dense          6.28        50.74   37.95   64.13   80.00   74.47
basemodel + dense + nsp    4.28        63.92   50.78   79.53   89.60   68.08
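
For readers unfamiliar with the retrieval metrics in the table, the sketch below computes mean rank, MRR and recall at k from the rank the model assigns to the ground-truth answer in each dialog round. These are the standard definitions, scaled to percentages as in the table. The function name retrieval_metrics and the toy ranks are hypothetical, and this is not the evaluation code used by the test server. NDCG is omitted because it additionally needs the dense relevance annotations.

```python
def retrieval_metrics(gt_ranks, ks=(1, 5, 10)):
    """Compute mean rank, MRR and recall@k from 1-based ground-truth ranks."""
    n = len(gt_ranks)
    metrics = {
        "mean_rank": sum(gt_ranks) / n,
        # Mean reciprocal rank, reported as a percentage.
        "mrr": 100.0 * sum(1.0 / r for r in gt_ranks) / n,
    }
    for k in ks:
        # Share of rounds where the ground truth is ranked within the top k.
        metrics[f"r@{k}"] = 100.0 * sum(r <= k for r in gt_ranks) / n
    return metrics


# Toy example: ground-truth ranks for four dialog rounds.
metrics = retrieval_metrics([1, 2, 4, 10])
for name, value in metrics.items():
    print(f"{name}: {value:.2f}")
```

Lower is better for mean rank; higher is better for MRR and recall at k.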

Logging

We use Visdom for all logging. To use this feature, set the visdom_server, visdom_port and enable_visdom arguments in options.py.

Visualizing Results

Coming soon

Acknowledgements

Builds on Jiasen Lu's ViLBERT implementation.

License

BSD
