xiaoxiaoheimei / seqdialn

Code for reproducing results in our paper SeqDialN: Sequential Visual Dialog Networks in Joint Visual-Linguistic Representation Space.

License: BSD 3-Clause "New" or "Revised" License

SeqDialN

If you find this work useful in your research, please consider citing our paper:

@inproceedings{yang-etal-2021-seqdialn,
    title = "{S}eq{D}ial{N}: Sequential Visual Dialog Network in Joint Visual-Linguistic Representation Space",
    author = "Yang, Liu  and
      Meng, Fanqi  and
      Liu, Xiao  and
      Wu, Ming-Kuang Daniel  and
      Ying, Vicent  and
      Xu, James",
    booktitle = "Proceedings of the 1st Workshop on Document-grounded Dialogue and Conversational Question Answering (DialDoc 2021)",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.dialdoc-1.2",
    doi = "10.18653/v1/2021.dialdoc-1.2",
    pages = "8--17"
}

Setup and Dependencies

This code is implemented using PyTorch v1.2. We recommend using Anaconda or Miniconda to set up the environment.

  1. Install the Anaconda or Miniconda distribution (Python 3+) from their download site.

  2. Create an environment named visdial and install all dependencies from the environment.yml file.

conda env create -f environment.yml
conda activate visdial

Download Data

  1. Download the VisDial v1.0 dialog JSON files: training set, validation set and test set.

  2. Download the dense annotations for the validation set and the training subset.

  3. Download the word counts for the VisDial v1.0 train split here. They are used to build the vocabulary.

  4. Download the pre-trained GloVe word vectors from here and unzip them.

  5. Download the pre-extracted image features of the VisDial v1.0 images, extracted using a Faster-RCNN pre-trained on Visual Genome. Features for v1.0 train, val and test are available for download at these links.

  6. Check that all required files are in the following locations so that the default arguments work:
$PROJECT_ROOT/data/glove.840B.300d/glove.840B.300d.txt
$PROJECT_ROOT/data/features_faster_rcnn_x101_test.h5
$PROJECT_ROOT/data/features_faster_rcnn_x101_train.h5
$PROJECT_ROOT/data/features_faster_rcnn_x101_val.h5
$PROJECT_ROOT/data/visdial_1.0_test.json
$PROJECT_ROOT/data/visdial_1.0_train_dense_sample.json
$PROJECT_ROOT/data/visdial_1.0_train.json
$PROJECT_ROOT/data/visdial_1.0_val_dense_annotations.json
$PROJECT_ROOT/data/visdial_1.0_val.json
$PROJECT_ROOT/data/visdial_1.0_word_counts_train.json
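
Before preprocessing, it may help to confirm that every file is in place. The following is a minimal sketch and not part of the repository: the helper name `missing_files` is hypothetical, and the relative paths are copied from the list above.

```python
from pathlib import Path

# Relative paths copied from the file list above.
REQUIRED_FILES = [
    "data/glove.840B.300d/glove.840B.300d.txt",
    "data/features_faster_rcnn_x101_test.h5",
    "data/features_faster_rcnn_x101_train.h5",
    "data/features_faster_rcnn_x101_val.h5",
    "data/visdial_1.0_test.json",
    "data/visdial_1.0_train_dense_sample.json",
    "data/visdial_1.0_train.json",
    "data/visdial_1.0_val_dense_annotations.json",
    "data/visdial_1.0_val.json",
    "data/visdial_1.0_word_counts_train.json",
]


def missing_files(project_root):
    """Return the required files that do not exist under project_root."""
    root = Path(project_root)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    # Hypothetical usage: run from $PROJECT_ROOT.
    missing = missing_files(".")
    for rel in missing:
        print(f"missing: {rel}")
    if not missing:
        print("all required files are present")
```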

Preprocess Data

Generate token counts for DistilBERT

To use DistilBERT embeddings, generate the BERT token counts JSON file used to build the vocabulary:

python scripts/gen_bert_token_count.py

The output json file should appear as $PROJECT_ROOT/data/visdial_1.0_bert_token_count.json.

Extract training subset with adjusted dense annotations

Only a subset of the training set has dense annotations. Extract this subset and adjust the gt_relevance values so that the model can be fine-tuned with dense annotations using our re-weighting method.

python scripts/extract_train.py --adjust-gt-relevance

Two new JSON files will appear as $PROJECT_ROOT/data/visdial_1.0_train_dense_sub.json and $PROJECT_ROOT/data/visdial_1.0_train_dense_sample_adjusted.json.

Training

Base model training

To train the base model (no finetuning on dense annotations):

python train.py \
  --config-yml configs/disc_mrn_be.yml \
  --gpu-ids 0 \
  --cpu-workers 8 \
  --validate \
  --save-dirpath checkpoints/disc_mrn_be/

To train a different type of model, pass a different configuration file path to --config-yml. The model types and their corresponding configuration files are shown in the tables below:

Discriminative:

Model   SeqIPN-GE-D      SeqIPN-BE-D      SeqMRN-GE-D      SeqMRN-BE-D
Config  disc_ipn_ge.yml  disc_ipn_be.yml  disc_mrn_ge.yml  disc_mrn_be.yml

Generative:

Model   SeqIPN-GE-G     SeqIPN-BE-G     SeqMRN-GE-G     SeqMRN-BE-G
Config  gen_ipn_ge.yml  gen_ipn_be.yml  gen_mrn_ge.yml  gen_mrn_be.yml

Pass more IDs to --gpu-ids to train on multiple GPUs. For example, --gpu-ids 0 1 2 3 trains the model on 4 GPUs.

Saving model checkpoints

The script saves a model checkpoint after every epoch to the path specified by --save-dirpath.

Logging

We use Tensorboard for logging training progress. Execute

tensorboard --logdir checkpoints/ --port 8008

and visit localhost:8008 in the browser.

Fine-tune with dense annotations

To fine-tune the base model with dense annotations:

python train_stage2.py \
  --config-yml configs/disc_mrn_be_ft.yml \
  --gpu-ids 0 \
  --cpu-workers 8 \
  --validate \
  --load-pthpath checkpoints/disc_mrn_be/checkpoint_12.pth \
  --save-dirpath checkpoints/disc_mrn_be_ft/

You should specify the corresponding base model checkpoint path with --load-pthpath.
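
To avoid typing the checkpoint path by hand, a small helper can locate one. This is a sketch under assumptions: the `checkpoint_<N>.pth` naming is inferred from the example paths in the commands above, and the helper `latest_checkpoint` is hypothetical. The latest epoch is not necessarily the best one, so check the validation metrics before choosing.

```python
import re
from pathlib import Path


def latest_checkpoint(save_dirpath):
    """Return the checkpoint_<N>.pth path with the largest N, or None.

    Assumes checkpoints follow the naming seen in the example commands.
    """
    pattern = re.compile(r"checkpoint_(\d+)\.pth")
    best = None
    for path in Path(save_dirpath).iterdir():
        match = pattern.fullmatch(path.name)
        if match is None:
            continue  # skip config files, logs and other entries
        epoch = int(match.group(1))
        if best is None or epoch > best[0]:
            best = (epoch, path)
    return None if best is None else best[1]
```

The returned path can then be passed to --load-pthpath.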

Both discriminative and generative base models can be fine-tuned with dense annotations, but fine-tuning does not boost NDCG much for generative models. The model types and their corresponding fine-tuning configuration files are shown in the tables below:

Discriminative:

Model   SeqIPN-GE-D         SeqIPN-BE-D         SeqMRN-GE-D         SeqMRN-BE-D
Config  disc_ipn_ge_ft.yml  disc_ipn_be_ft.yml  disc_mrn_ge_ft.yml  disc_mrn_be_ft.yml

Generative:

Model   SeqIPN-GE-G        SeqIPN-BE-G        SeqMRN-GE-G        SeqMRN-BE-G
Config  gen_ipn_ge_ft.yml  gen_ipn_be_ft.yml  gen_mrn_ge_ft.yml  gen_mrn_be_ft.yml
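
The configuration file names in the tables follow a regular pattern, so the mapping from model name to file can be written down directly. The sketch below is illustrative only: the helper `config_for` is hypothetical, and the `configs/` prefix is taken from the example commands.

```python
def config_for(model_name, finetune=False):
    """Map a model name such as 'SeqMRN-BE-D' to its configuration file path.

    Pattern read off the tables: <disc|gen>_<ipn|mrn>_<ge|be>[_ft].yml
    """
    encoder, embedding, decoder = model_name.split("-")
    decoder_part = {"D": "disc", "G": "gen"}[decoder]
    encoder_part = {"SeqIPN": "ipn", "SeqMRN": "mrn"}[encoder]
    embedding_part = {"GE": "ge", "BE": "be"}[embedding]
    suffix = "_ft" if finetune else ""
    return f"configs/{decoder_part}_{encoder_part}_{embedding_part}{suffix}.yml"


if __name__ == "__main__":
    print(config_for("SeqMRN-BE-D"))
    print(config_for("SeqIPN-GE-G", finetune=True))
```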

Evaluation

Evaluation of a trained model checkpoint can be done as follows:

python evaluate.py \
  --config-yml checkpoints/disc_mrn_be_ft/config.yml \
  --split val \
  --gpu-ids 0 \
  --cpu-workers 8 \
  --load-pthpath checkpoints/disc_mrn_be_ft/checkpoint_2.pth \
  --save-ranks-path results/ranks/disc_mrn_be_ft.json \
  --save-preds-path results/preds/disc_mrn_be_ft_preds.h5

This will report the metrics from the Visual Dialog paper: R@{1, 5, 10}, mean rank (Mean), mean reciprocal rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
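
For readers unfamiliar with these metrics, the sketch below shows how they can be computed for illustration. It is not the repository's evaluation code. The function names are hypothetical, and truncating NDCG at k, the number of candidates with non-zero relevance, is an assumption based on the usual Visual Dialog convention.

```python
import math


def retrieval_metrics(gt_ranks, ks=(1, 5, 10)):
    """Compute R@k, mean rank and MRR.

    gt_ranks holds the 1-based rank of the ground-truth answer per question.
    """
    n = len(gt_ranks)
    metrics = {f"r@{k}": sum(rank <= k for rank in gt_ranks) / n for k in ks}
    metrics["mean"] = sum(gt_ranks) / n
    metrics["mrr"] = sum(1.0 / rank for rank in gt_ranks) / n
    return metrics


def ndcg(scores, relevance):
    """NDCG for one question over its candidate answers.

    Assumption: k is the number of candidates with non-zero relevance.
    """
    k = sum(1 for rel in relevance if rel != 0)
    if k == 0:
        return 0.0
    # Indices of the top-k candidates according to the model's scores.
    top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # The best achievable ordering puts the most relevant candidates first.
    ideal = sorted(relevance, reverse=True)[:k]
    dcg = sum(relevance[i] / math.log2(pos + 2) for pos, i in enumerate(top_k))
    idcg = sum(rel / math.log2(pos + 2) for pos, rel in enumerate(ideal))
    return dcg / idcg
```

The first three metrics depend only on where the single ground-truth answer is ranked, while NDCG uses the dense relevance annotations over all candidates.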

If --save-ranks-path is specified, the script generates an EvalAI submission JSON file.

If --save-preds-path is specified, the script saves the model's raw prediction scores to a .h5 file, which can be used to ensemble models.

The metrics reported here should match those reported by EvalAI for a submission in the val phase.

To generate a submission file or a raw prediction .h5 file for the test-std or test-challenge phase, replace --split val with --split test.

Ensemble

To ensemble the predictions of several models, evaluate each model separately using the evaluate.py script described above and save each model's raw prediction scores to a .h5 file. Make sure these .h5 files belong to the same split (val or test) and are in the same folder. Then:

python ensemble.py \
  --preds-folder results/preds/ \
  --split val \
  --method sa \
  --norm-order none \
  --save-ranks-path results/ranks/ensemble.json

This will find all .h5 files in the folder specified by --preds-folder and ensemble the results using the method specified by --method. Four ensemble methods are currently supported: sa for "Score Average", pa for "Probability Average", ra for "Rank Average" and rra for "Reciprocal Rank Average".

--norm-order should be none or an integer, and it only takes effect when --method is sa. When it is none, the models' prediction scores are averaged directly. When it is an integer, the prediction scores are normalized before averaging. For example, if --norm-order is 2, the prediction scores are normalized using the L2 norm before averaging.
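
The four methods can be sketched for a single question as follows. This is an illustration under stated assumptions, not the code in ensemble.py: the function names are hypothetical, "probability" is assumed to mean a softmax over each model's raw scores, and rank averages are negated so that a higher combined value always means a better candidate.

```python
import math


def l_norm(scores, order):
    """L-p norm of a score list; returns 1.0 for an all-zero list."""
    return sum(abs(s) ** order for s in scores) ** (1.0 / order) or 1.0


def softmax(scores):
    """Numerically stable softmax (assumed meaning of 'probability')."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def ranks(scores):
    """1-based rank of each candidate; the highest score gets rank 1."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    result = [0] * len(scores)
    for position, index in enumerate(order, start=1):
        result[index] = position
    return result


def ensemble(models_scores, method="sa", norm_order=None):
    """Combine per-model score lists for one question; higher is better."""
    if method == "sa":
        if norm_order is None:
            transformed = models_scores
        else:
            transformed = [[s / l_norm(m, norm_order) for s in m]
                           for m in models_scores]
    elif method == "pa":
        transformed = [softmax(m) for m in models_scores]
    elif method == "ra":
        # Negate so that a smaller average rank gives a larger value.
        transformed = [[-r for r in ranks(m)] for m in models_scores]
    elif method == "rra":
        transformed = [[1.0 / r for r in ranks(m)] for m in models_scores]
    else:
        raise ValueError(f"unknown method: {method}")
    count = len(transformed)
    return [sum(column) / count for column in zip(*transformed)]
```

Normalizing before a score average keeps a model with a large score scale from dominating the result, which is one plausible reason to set --norm-order.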

For the val split, the script reports all the metrics mentioned above and --save-ranks-path is optional. For the test split, you must specify --save-ranks-path to save the ensembled prediction ranks to a JSON file.

Acknowledgements

This code began as a fork of batra-mlp-lab/visdial-challenge-starter-pytorch.
