Calibrated Self-Rewarding Vision Language Models

Yiyang Zhou*, Zhiyuan Fan*, Dongjie Cheng*, Sihan Yang, Zhaorun Chen, Chenhang Cui, Xiyao Wang, Yun Li, Linjun Zhang, Huaxiu Yao

Hugging Face

[Project page]

Citation: If you find this repository useful for your research, please consider citing our paper:

@article{zhou2024calibrated,
  title={Calibrated Self-Rewarding Vision Language Models},
  author={Zhou, Yiyang and Fan, Zhiyuan and Cheng, Dongjie and Yang, Sihan and Chen, Zhaorun and Cui, Chenhang and Wang, Xiyao and Li, Yun and Zhang, Linjun and Yao, Huaxiu},
  journal={arXiv preprint arXiv:2405.14622},
  year={2024}
}

About CSR


Framework of Calibrated Self-Rewarding (CSR)

Existing methods use additional models or human annotations to curate preference data and enhance modality alignment through preference optimization. These methods are resource-intensive and may not effectively reflect the target LVLM's preferences, which makes the curated preference data easily distinguishable. To address these challenges, we propose Calibrated Self-Rewarding (CSR), which enables the model to self-improve by iteratively generating candidate responses, evaluating the reward for each response, and curating preference data for fine-tuning. For reward modeling, CSR adopts a step-wise strategy and incorporates visual constraints into the self-rewarding process to place more emphasis on the visual input.
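The reward-and-selection loop described above can be pictured with a minimal Python sketch. It is an illustration under stated assumptions, not the repository's implementation: each candidate response is assumed to be split into sentences (steps), each sentence is assumed to carry a self-generated text score from the LVLM and an image-relevance score, and the calibrated reward is assumed to be a weighted mix controlled by a hypothetical weight `lam`. The names `Sentence`, `calibrated_reward`, `response_reward`, and `select_preference_pair` are hypothetical, and the sentence-level search CSR performs during generation is simplified here to summing per-sentence rewards over finished candidates.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class Sentence:
    """One step of a candidate response (hypothetical structure)."""
    text: str
    text_score: float   # assumed: self-generated score from the LVLM itself
    image_score: float  # assumed: relevance of this sentence to the image


def calibrated_reward(sentence: Sentence, lam: float = 0.5) -> float:
    """Mix the self-generated score with a visual constraint (lam is assumed)."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return lam * sentence.text_score + (1.0 - lam) * sentence.image_score


def response_reward(response: Sequence[Sentence], lam: float = 0.5) -> float:
    """Step-wise reward: accumulate the calibrated reward sentence by sentence."""
    return sum(calibrated_reward(s, lam) for s in response)


def select_preference_pair(
    candidates: List[List[Sentence]], lam: float = 0.5
) -> Tuple[List[Sentence], List[Sentence]]:
    """Return (chosen, rejected): the highest- and lowest-reward candidates."""
    if len(candidates) < 2:
        raise ValueError("need at least two candidate responses")
    ranked = sorted(candidates, key=lambda r: response_reward(r, lam))
    return ranked[-1], ranked[0]
```

For example, with `lam=0.5` a grounded sentence scored `(-0.2, 0.9)` earns 0.35, while a fluent but ungrounded one scored `(-0.1, 0.1)` earns 0.0, so the grounded response becomes the chosen sample. Weighting the image term separately is what lets the visual constraint override a purely language-based preference.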


Left: LLaVA-1.5 models of different parameter sizes improve over successive CSR iterations. Right: image-relevance scores of the responses before and after applying CSR.

Through the online CSR process, the model's performance improves steadily across benchmarks, and the overall relevance of its responses to the visual input increases. CSR also narrows the gap between chosen and rejected responses, which raises the lower bound of the model's performance.

Installation

The build process is based on LLaVA 1.5:

  1. Clone the LLaVA repository, navigate to the LLaVA folder, and clone this repository inside it
git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
git clone https://github.com/YiyangZhou/CSR.git
  2. Install the package
conda create -n csr python=3.10 -y
conda activate csr
pip install --upgrade pip
pip install -e .
  3. Install additional packages for training
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
  4. Install the trl package
pip install trl
  5. Modify the TRL library to adapt DPO to LVLMs
cd *your conda path*/envs/csr/lib/python3.10/site-packages/trl/trainer/
# Replace dpo_trainer.py with the dpo_trainer.py provided in the 'train_csr' folder.
  6. Change the parent class of LLaVATrainer
cd ./LLaVA/llava/train

# Modify llava_trainer.py as follows (import DPOTrainer and use it as the parent class):

# from trl import DPOTrainer
# ...
# ...
# ...
# class LLaVATrainer(DPOTrainer):
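Step 5 asks you to overwrite trl's dpo_trainer.py by hand inside your conda environment. A minimal Python helper that avoids hunting for the conda path might look like this; it is a hypothetical sketch, not part of the CSR repository. It locates the installed trl package with `importlib.util.find_spec`, and it keeps a one-time `.bak` copy of trl's original file so the change can be undone. The function names are hypothetical, and the location of the patched file is something you pass in.

```python
import importlib.util
import shutil
from pathlib import Path
from typing import Optional


def find_trl_trainer_dir() -> Optional[Path]:
    """Return the installed trl/trainer directory, or None if trl is not installed."""
    spec = importlib.util.find_spec("trl")
    if spec is None or not spec.submodule_search_locations:
        return None
    # trl is a package, so its first search location is the package directory.
    return Path(list(spec.submodule_search_locations)[0]) / "trainer"


def replace_dpo_trainer(patched_file: Path, trainer_dir: Path) -> Path:
    """Overwrite trl's dpo_trainer.py with the patched copy.

    The original file is saved once as dpo_trainer.py.bak; the backup path is returned.
    """
    if not patched_file.is_file():
        raise FileNotFoundError(f"patched file not found: {patched_file}")
    target = trainer_dir / "dpo_trainer.py"
    backup = target.with_name(target.name + ".bak")
    if target.exists() and not backup.exists():
        shutil.copy2(target, backup)  # keep trl's original for rollback
    shutil.copy2(patched_file, target)
    return backup
```

Run from the LLaVA folder with the csr environment active, a call such as `replace_dpo_trainer(Path("./CSR/train_csr/dpo_trainer.py"), find_trl_trainer_dir())` would perform step 5, assuming the patched file sits at that relative path. Check that `find_trl_trainer_dir()` did not return None before making the call.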

Instruction

Before starting, you need to:

(1) Modify the paths in './CSR/scripts/run_train.sh' to your own paths.

(2) If you use wandb, enter your API key in './CSR/train_csr/train_dpo_lora.py' by replacing "your key" in 'wandb.login(key="your key")'.

(3) Download the image data from the COCO website into './data/images/' (or prepare your own images and prompt data).
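These prerequisites can be checked before starting a long run. The following is a hypothetical helper, not part of the repository. It assumes './data/images/' and './CSR/' are resolved from the same working directory, treats a leftover `key="your key"` in train_dpo_lora.py as an unfilled wandb key (ignore that message if you do not use wandb), and only checks that run_train.sh exists, since it cannot tell whether the paths inside it are correct.

```python
from pathlib import Path
from typing import List

# Assumed image extensions; COCO-2014 images are JPEG files.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def preflight(root: Path = Path(".")) -> List[str]:
    """Return human-readable problems; an empty list means all checks passed."""
    problems: List[str] = []

    image_dir = root / "data" / "images"
    has_images = image_dir.is_dir() and any(
        p.suffix.lower() in IMAGE_SUFFIXES for p in image_dir.iterdir()
    )
    if not has_images:
        problems.append(f"no images found in {image_dir}")

    run_train = root / "CSR" / "scripts" / "run_train.sh"
    if not run_train.is_file():
        problems.append(f"missing {run_train}")

    # Only relevant if you log to wandb: the placeholder key is still present.
    train_script = root / "CSR" / "train_csr" / "train_dpo_lora.py"
    if train_script.is_file() and 'key="your key"' in train_script.read_text():
        problems.append(f"wandb key placeholder still present in {train_script}")

    return problems
```

Calling `for problem in preflight(): print(problem)` from the working directory reports anything that still needs attention.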

Step 1. Construct Preference Data.

First, place the COCO-2014 train images in './data/images/'. Then run the following steps in sequence.

cd ./CSR/inference_csr
bash ./step1.sh
bash ./step2.sh
bash ./step3.sh

This produces the preference dataset. Because the process takes a long time, we also provide our preference datasets on Hugging Face.

Step 2. Direct Preference Optimization (DPO).

bash ./CSR/scripts/run_train.sh

Step 3. Iterative Learning.

After completing a round of CSR training, merge the current LoRA checkpoint into its base model. Then use the merged checkpoint as the new base model and repeat Step 1 and Step 2 in sequence. For --model-base, pass the LLaVA-1.5 checkpoint in the first round, the merged Iter-1 checkpoint in the second round, the merged Iter-2 checkpoint in the third round, and so on.

python ./scripts/merge_lora_weights.py --model-path "your LoRA checkpoint path" --model-base "your base model path" --save-model-path "your merged model save path"
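The chain of --model-base values across rounds can be made explicit with a small helper that generates one merge command per round. This is a hypothetical sketch: all checkpoint and output paths are placeholders. Step 1 and Step 2 must run between consecutive merges to produce the next LoRA checkpoint, so the helper only builds the commands rather than running the whole loop.

```python
import shlex
from typing import List


def merge_commands(
    base_model: str, lora_checkpoints: List[str], save_dirs: List[str]
) -> List[List[str]]:
    """Build one merge_lora_weights.py command per CSR round.

    Round 1 merges into the original LLaVA-1.5 checkpoint; every later round
    merges into the model saved by the previous round.
    """
    if len(lora_checkpoints) != len(save_dirs):
        raise ValueError("need exactly one save directory per LoRA checkpoint")
    commands: List[List[str]] = []
    current_base = base_model
    for lora_path, save_path in zip(lora_checkpoints, save_dirs):
        commands.append([
            "python", "./scripts/merge_lora_weights.py",
            "--model-path", lora_path,
            "--model-base", current_base,
            "--save-model-path", save_path,
        ])
        current_base = save_path  # the merged model is the next round's base
    return commands


# Hypothetical paths for two rounds; print the commands so they can be reviewed.
for cmd in merge_commands(
    "checkpoints/llava-v1.5-7b",
    ["checkpoints/csr-lora-iter1", "checkpoints/csr-lora-iter2"],
    ["checkpoints/csr-iter1", "checkpoints/csr-iter2"],
):
    print(shlex.join(cmd))
```

Each generated command can also be run with `subprocess.run(cmd, check=True)` once its LoRA checkpoint exists.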

Data and Models

We provide CSR training data and model weights on Hugging Face. Please refer to the Instruction section for usage.

Dataset   | Download       | Model (7B)   | Download       | Model (13B)   | Download
CSR-iter0 | 🤗 HuggingFace | CSR-7B-iter1 | 🤗 HuggingFace | CSR-13B-iter1 | 🤗 HuggingFace
CSR-iter1 | 🤗 HuggingFace | CSR-7B-iter2 | 🤗 HuggingFace | CSR-13B-iter2 | 🤗 HuggingFace
CSR-iter2 | 🤗 HuggingFace | CSR-7B-iter3 | 🤗 HuggingFace | CSR-13B-iter3 | 🤗 HuggingFace

The prompt dataset and the mapping files between the llava and hf-llava formats are available in './CSR/inference_csr/data'.

Evaluation

Here are three convenient ways to perform evaluations:

  1. Use the eval scripts provided in LLaVA.

  2. Use lmms-eval, a general-purpose evaluation platform.

  3. Use the CHAIR metrics in LURE.
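The CHAIR metrics measure object hallucination in captions: CHAIR_S is the fraction of captions that mention at least one object absent from the image, and CHAIR_I is the fraction of all mentioned objects that are absent. A simplified Python sketch of the two ratios follows. It assumes the mentioned objects have already been extracted from each caption and normalized; the real metric also extracts objects from free-form text and maps synonyms to object categories, which is not shown. The function name `chair` is hypothetical, and lower values are better.

```python
from typing import Sequence, Set, Tuple


def chair(
    mentioned: Sequence[Set[str]], ground_truth: Sequence[Set[str]]
) -> Tuple[float, float]:
    """Return (CHAIR_s, CHAIR_i) for pre-extracted object sets.

    mentioned[k]: objects mentioned in caption k (already normalized).
    ground_truth[k]: objects actually present in image k.
    """
    if len(mentioned) != len(ground_truth):
        raise ValueError("one ground-truth set is required per caption")
    total_mentions = hallucinated_mentions = hallucinated_captions = 0
    for objects, present in zip(mentioned, ground_truth):
        hallucinated = objects - present  # mentioned but not in the image
        total_mentions += len(objects)
        hallucinated_mentions += len(hallucinated)
        hallucinated_captions += int(bool(hallucinated))
    chair_s = hallucinated_captions / len(mentioned) if mentioned else 0.0
    chair_i = hallucinated_mentions / total_mentions if total_mentions else 0.0
    return chair_s, chair_i
```

For three captions mentioning {dog, frisbee}, {cat, kite}, and {car}, where only "kite" is absent from its image, one of three captions hallucinates (CHAIR_S = 1/3) and one of five mentions is hallucinated (CHAIR_I = 0.2).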

Acknowledgement

  • This repository is built upon LLaVA!
  • We thank the Center for AI Safety for supporting our computing needs. This research was supported by Cisco Faculty Research Award.
