CONICA: A Contrastive Image Captioning Framework with Robust Similarity Learning

This repository contains the code for the paper "CONICA: A Contrastive Image Captioning Framework with Robust Similarity Learning" (accepted to ACM MM 2023).

Requirements

  • python>=3.8
  • pytorch=1.11.0 & torchvision=0.12.0
  • transformers=4.29.1
  • tokenizers=0.13.3
  • clip=1.0
  • other packages: pycocotools, pycocoevalcap, tqdm, pandas, tensorboard and timm
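Before running anything, you can compare the pinned versions above against the active environment. The following is a minimal standard-library sketch. The pip distribution names (`torch` for PyTorch, `torchvision`, `transformers`, `tokenizers`, `clip`) are assumptions about how each package is registered, and the check compares only the version string after dropping local build tags such as `+cu113`.

```python
from importlib import metadata
from typing import Callable, Dict, List

# Pinned versions from the Requirements list. The distribution names are
# assumptions (e.g. PyTorch is published on pip as "torch", not "pytorch").
PINNED: Dict[str, str] = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "transformers": "4.29.1",
    "tokenizers": "0.13.3",
    "clip": "1.0",
}


def check_pins(
    pinned: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> List[str]:
    """Return one message per package that is missing or not at its pinned version."""
    problems = []
    for name, wanted in pinned.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed (want {wanted})")
            continue
        # Ignore local build tags such as "+cu113" when comparing.
        if found.split("+")[0] != wanted:
            problems.append(f"{name}: found {found}, want {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_pins(PINNED):
        print(line)
```

An empty printout means every pinned package matched.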

Usage

1. Preparation

Features

python prepare/prepro_feats.py -model_name ViT-L/14@336px -input_resolution 336 -dataset "your path to dataset (mscoco or others)"

Dataset

python prepare/prepro_datasets.py -karpathy_split_json "your path to karpathy split" -output_file /dataset/mscoco.csv

You can download the Karpathy split JSON from this link.
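For orientation, the Karpathy split file (commonly distributed as `dataset_coco.json`) holds an `images` list. Each entry has `filepath`, `filename`, `cocoid`, a `split` field (`train`, `restval`, `val` or `test`) and a `sentences` list whose items carry the caption in `raw`. The sketch below flattens that structure into one row per caption. It is not `prepare/prepro_datasets.py`: the columns of `mscoco.csv` are not documented here, so the row keys are hypothetical. Folding `restval` into `train` follows the common convention, not anything stated in this README.

```python
import json
from typing import Dict, List


def karpathy_rows(split_json: Dict) -> List[Dict[str, object]]:
    """Flatten a Karpathy split dict into one row per (image, caption) pair.

    The output keys (image_id, file_path, split, caption) are hypothetical and
    need not match the columns written by prepro_datasets.py.
    """
    rows = []
    for img in split_json["images"]:
        # Common convention: 'restval' images are added to the training split.
        split = "train" if img["split"] == "restval" else img["split"]
        for sent in img["sentences"]:
            rows.append({
                "image_id": img["cocoid"],
                "file_path": f"{img['filepath']}/{img['filename']}",
                "split": split,
                "caption": sent["raw"].strip(),
            })
    return rows


def load_karpathy(path: str) -> List[Dict[str, object]]:
    """Read the split JSON from disk and flatten it."""
    with open(path, encoding="utf-8") as f:
        return karpathy_rows(json.load(f))
```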

HF Configuration & Tokenizers

python prepare/prepro_conf_tokenizer.py -dataset_path /dataset/mscoco.csv -output_file conica-clip

Word Frequency

python prepare/prepro_ngrams.py -input_csv /dataset/caption/mscoco.csv -output_pkl /dataset/caption/cache_document_frequency/coco-train-words.p
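CIDEr-D weights each n-gram by its inverse document frequency over the reference set, where a "document" is the set of reference captions for one image. The cached `coco-train-words.p` file presumably stores these counts for the training split, although this README does not describe its layout. Here is a minimal sketch of the count for n = 1..4. It assumes simple lower-casing and whitespace tokenization, so the actual script's tokenization and pickle format may differ.

```python
from collections import Counter
from typing import Dict, List


def ngrams(tokens: List[str], n_max: int = 4) -> Counter:
    """Count every n-gram of length 1..n_max in one token list."""
    counts = Counter()
    for n in range(1, n_max + 1):
        for i in range(len(tokens) - n + 1):
            counts[tuple(tokens[i:i + n])] += 1
    return counts


def document_frequency(refs: Dict[str, List[str]], n_max: int = 4) -> Counter:
    """For each n-gram, count how many images have it in at least one reference caption."""
    df = Counter()
    for captions in refs.values():
        seen = set()
        for caption in captions:
            # Assumed tokenization: lower-case and split on whitespace.
            seen.update(ngrams(caption.lower().split(), n_max))
        df.update(seen)  # each image contributes at most 1 per n-gram
    return df
```

In pycocoevalcap's CIDEr-D scorer, these counts become IDF weights of the form log(N) - log(max(1, df)), where N is the number of images.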

2. Training

2.1. XE stage

python train.py \
--config_name conica-clip \
--max_features_len 256 \
--feature_path /dataset/caption/mscoco/features/ViT-L/14@336px \
--output_dir output/clip-vit_xe/checkpoints \
--do_train \
--evaluation_strategy epoch \
--logging_strategy steps \
--logging_steps 100 \
--logging_dir output/clip-vit_xe/logs \
--save_strategy epoch \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 32 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-4 \
--weight_decay 1e-2 \
--num_train_epochs 30 \
--lr_scheduler_type linear \
--warmup_ratio 0.1 \
--fp16 \
--gradient_checkpointing \
--dataloader_pin_memory true \
--dataloader_num_workers 8
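The flags suggest that `train.py` uses the Hugging Face `Trainer`. If so, `--lr_scheduler_type linear` with `--warmup_ratio 0.1` raises the learning rate linearly from 0 to `--learning_rate` over the first 10% of optimizer steps (rounded up), then decays it linearly to 0. The sketch below reproduces that schedule in plain Python so you can inspect it without a GPU. The sample count passed to `optimizer_steps` is a hypothetical input: this README does not fix the number of training samples or GPUs. Each optimizer step consumes `per_device_train_batch_size × gradient_accumulation_steps × num_gpus` samples.

```python
import math


def linear_warmup_lr(step: int, total_steps: int, peak_lr: float, warmup_ratio: float) -> float:
    """Learning rate at an optimizer step: linear warmup, then linear decay to 0.

    Mirrors the shape of transformers.get_linear_schedule_with_warmup, with the
    warmup length derived from warmup_ratio by rounding up.
    """
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def optimizer_steps(num_samples: int, per_device_bs: int, grad_accum: int,
                    num_gpus: int, epochs: int) -> int:
    """Approximate number of optimizer steps for a run (ignores dataloader edge cases)."""
    samples_per_step = per_device_bs * grad_accum * num_gpus
    return math.ceil(num_samples / samples_per_step) * epochs
```

By the same arithmetic, the RL stage's `--gradient_accumulation_steps 2` gives 128 samples per optimizer step on each GPU. Its `--lr_scheduler_type constant` keeps the rate at 5e-6 throughout.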

Then choose the checkpoint with the highest CIDEr score to initialize the RL stage.
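The Hugging Face `Trainer` writes a `trainer_state.json` into every `checkpoint-<step>` directory. Its `log_history` list records each evaluation's metrics under `eval_`-prefixed keys, together with the `step`. Because evaluation and saving both run once per epoch here, every evaluated step should have a matching checkpoint folder. Below is a minimal sketch for picking the best one. It assumes CIDEr is logged under a key like `eval_CIDEr`, which is a hypothetical name: check one `trainer_state.json` for the key that `train.py` actually uses.

```python
import json
import os
from typing import Dict, List, Optional, Tuple


def best_eval_step(log_history: List[Dict], metric_key: str) -> Optional[Tuple[int, float]]:
    """Return (step, value) of the evaluation entry with the highest metric_key."""
    evals = [e for e in log_history if metric_key in e]
    if not evals:
        return None
    best = max(evals, key=lambda e: e[metric_key])
    return best["step"], best[metric_key]


def best_checkpoint_dir(output_dir: str, metric_key: str = "eval_CIDEr") -> Optional[str]:
    """Read the newest checkpoint's trainer_state.json and map the best step to its folder.

    metric_key is a hypothetical default; the Trainer prefixes eval metrics with "eval_".
    """
    ckpts = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    if not ckpts:
        return None
    # The newest checkpoint's log_history contains every earlier evaluation.
    latest = max(ckpts, key=lambda d: int(d.split("-")[-1]))
    with open(os.path.join(output_dir, latest, "trainer_state.json"), encoding="utf-8") as f:
        state = json.load(f)
    found = best_eval_step(state["log_history"], metric_key)
    return None if found is None else os.path.join(output_dir, f"checkpoint-{found[0]}")
```

For example, `best_checkpoint_dir("output/clip-vit_xe/checkpoints")` returns a path you could pass as `--resume_from_checkpoint` in the RL command.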

Alternatively, you can download the XE checkpoint from this Google Drive link.

2.2. RL stage

python train.py \
--config_name conica-clip \
--max_features_len 256 \
--feature_path /dataset/caption/mscoco/features/ViT-L/14@336px \
--output_dir output/clip-vit_rl/checkpoints \
--scst \
--init_tau \
--scst_num_sample_sequences 5 \
--do_train \
--evaluation_strategy epoch \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 32 \
--resume_from_checkpoint "xe-checkpoint" \
--gradient_accumulation_steps 2 \
--learning_rate 5e-6 \
--weight_decay 1e-2 \
--num_train_epochs 20 \
--lr_scheduler_type constant \
--warmup_steps 0 \
--logging_strategy steps \
--logging_steps 100 \
--logging_dir output/clip-vit_rl/logs \
--save_strategy epoch \
--fp16 \
--dataloader_num_workers 12 \
--dataloader_pin_memory
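`--scst` presumably enables self-critical sequence training, and `--scst_num_sample_sequences 5` samples five captions per image. This README does not say which reward baseline `train.py` uses. One common choice for multi-sample SCST is to compare each sample's reward (typically CIDEr-D) with the mean reward of the other samples for the same image. The sketch below illustrates that leave-one-out variant in plain Python. It is not CONICA's actual loss.

```python
from typing import List


def leave_one_out_advantages(rewards: List[float]) -> List[float]:
    """Advantage of each sample = its reward minus the mean reward of the other samples."""
    k = len(rewards)
    if k < 2:
        raise ValueError("need at least two samples per image for a leave-one-out baseline")
    total = sum(rewards)
    return [r - (total - r) / (k - 1) for r in rewards]


def scst_loss(rewards: List[List[float]], seq_logprobs: List[List[float]]) -> float:
    """Policy-gradient loss averaged over all sampled captions.

    rewards[i][j]      : reward of the j-th sampled caption for image i
    seq_logprobs[i][j] : summed log-probability of that caption under the model
    In a real trainer the log-probabilities would be tensors so the loss can be
    backpropagated; the advantages are treated as constants.
    """
    terms = []
    for img_rewards, img_logprobs in zip(rewards, seq_logprobs):
        for adv, logp in zip(leave_one_out_advantages(img_rewards), img_logprobs):
            # Minimizing -adv * logp raises the probability of above-baseline captions.
            terms.append(-adv * logp)
    return sum(terms) / len(terms)
```

The advantages for one image always sum to zero. Captions scoring above their siblings' average are reinforced, and those below it are suppressed, with no separate greedy decode needed.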

You can download the RL checkpoint from this Google Drive link.

3. Test

python train.py \
--config_name conica-clip \
--feature_path /dataset/caption/mscoco/features/ViT-L/14@336px \
--output_dir predict_clip \
--resume_from_checkpoint "rl-checkpoint" \
--do_predict \
--per_device_eval_batch_size 64 \
--gradient_accumulation_steps 1 \
--dataloader_num_workers 4

Citations

Please consider citing this paper if you use this code.

@inproceedings{deng2023conica,
author = {Deng, Lin and Zhong, Yuzhong and Wang, Maoning and Zhang, Jianwei},
title = {CONICA: A Contrastive Image Captioning Framework with Robust Similarity Learning},
year = {2023},
booktitle = {Proceedings of the 31st ACM International Conference on Multimedia},
pages = {5109--5119}
}
