Language-Model-STS-CFT

[Paper Coming Soon...] [Hugging Face 🤗]

This project aims to improve the text embeddings of smaller language models (LMs) with up to 2B parameters using a contrastive fine-tuning technique. Specifically, the InfoNCE loss is used as the training objective.

$$\min - \log \frac{e^{\text{sim}(\mathbf{h}_i, \mathbf{h}_i^+) / \tau}}{\sum_j \left( e^{\text{sim}(\mathbf{h}_i, \mathbf{h}_j^+) / \tau} + e^{\text{sim}(\mathbf{h}_i, \mathbf{h}_j^-) / \tau} \right)}$$

where $\mathbf{h}_i$ denotes the embedding vector of a premise $x_i$, $\mathbf{h}_j^+$ and $\mathbf{h}_j^-$ denote the embeddings of the in-batch entailments and hard negatives, $\tau$ is a temperature hyperparameter, and $\text{sim}(\mathbf{h}_i, \mathbf{h}_i^+)$ is the cosine similarity between the embedding vectors $\mathbf{h}_i$ and $\mathbf{h}_i^+$.
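As a concrete illustration, the loss above can be sketched in a few lines of numpy (a minimal sketch with hypothetical helper names; the actual training code operates on model outputs in PyTorch):

```python
import numpy as np

def cosine_sim(a, b):
    # Row-wise cosine similarity matrix between two (n, d) arrays.
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

def info_nce_loss(h, h_pos, h_neg, tau=0.05):
    # h: (n, d) anchor embeddings, h_pos: (n, d) entailment embeddings,
    # h_neg: (n, d) hard-negative embeddings; tau is the temperature.
    sim_pos = cosine_sim(h, h_pos) / tau   # sim(h_i, h_j^+) / tau
    sim_neg = cosine_sim(h, h_neg) / tau   # sim(h_i, h_j^-) / tau
    # Denominator sums over all in-batch entailments and hard negatives.
    logits = np.concatenate([sim_pos, sim_neg], axis=1)  # (n, 2n)
    log_denom = np.log(np.exp(logits).sum(axis=1))
    # Numerator uses the matching pair sim(h_i, h_i^+), i.e. the diagonal.
    return float(np.mean(log_denom - np.diag(sim_pos)))
```

Minimizing this pulls each anchor toward its own entailment while pushing it away from every other premise's entailment and hard negative in the batch.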

We employ LoRA as our parameter-efficient fine-tuning technique to reduce memory requirements.
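The idea behind LoRA, in a minimal numpy sketch (illustrative dimensions and initialization; in practice a LoRA library applies this to selected weight matrices rather than this hand-rolled version):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 8, 2, 4                 # hidden size, LoRA rank, scaling (illustrative)

W = rng.normal(size=(d, d))           # frozen pretrained weight
A = rng.normal(size=(r, d)) * 0.01    # trainable down-projection
B = np.zeros((d, r))                  # trainable up-projection, zero-initialized

def lora_forward(x):
    # Frozen path plus low-rank trainable update: x W^T + (alpha/r) x A^T B^T.
    return x @ W.T + (alpha / r) * (x @ A.T @ B.T)
```

Only A and B (2·d·r parameters per adapted matrix) receive gradients, which is what cuts the optimizer-state and gradient memory.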

Embedding Extraction

  • Every prompt is appended with an [EOS] token.
  • The embedding vector is extracted from the last-layer hidden state at the position of this [EOS] token.
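A sketch of that pooling step, assuming right-padding and a hidden-state tensor already computed by the model (array shapes and names are illustrative):

```python
import numpy as np

def eos_embedding(last_hidden, attention_mask):
    # last_hidden: (batch, seq_len, dim) hidden states from the final layer.
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    # With right-padding, the appended [EOS] is the last real token.
    eos_pos = attention_mask.sum(axis=1) - 1
    return last_hidden[np.arange(last_hidden.shape[0]), eos_pos]
```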

Fine-tuned Weights

We have fine-tuned 3 models and provide their LoRA adapter weights in this Hugging Face 🤗 collection.

The base models consist of

  1. MiniCPM-2B-dpo-bf16
  2. Gemma-2B-it
  3. Phi-2

Performance numbers and fine-tuning details can be found on each model's Hugging Face page.

Dataset

We use the processed NLI dataset as our fine-tuning dataset. It consists of 275K triplets of anchors, their corresponding entailments, and hard negatives. Please follow this README to see how to download the dataset.

Fine-tuning with your own resources

If you would like to fine-tune the LMs with your own resources, we provide the code for you. Our code works in multi-GPU settings: the more GPUs you have, the larger the batch size you can fine-tune with.

First, you need to set up the virtual environment. We have provided an environment setup file for you.

conda env create --file environment.yml
conda activate cft

Then, download the processed NLI dataset following this README.

After that, please follow this README for the fine-tuning steps.

Footnote

This work is the final project of the Natural Language Processing Spring 2024 course at Tsinghua University 🟣. We would like to express our sincere gratitude to this course!


Issues

Enlarge Batch Size Per GPU

The result of training with batch_size_per_gpu = 2 is unstable.

  • The success of contrastive learning depends heavily on batch size
  • batch_size_per_gpu = 2 already consumes ~45% of GPU memory (RTX 3090)
  • Training currently uses a standard DDP strategy + bf16 mixed precision + LoRA
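For reference, the global batch size under DDP with gradient accumulation can be computed as below (with the caveat that in-batch negatives for the contrastive loss may still be limited to each GPU's local batch unless embeddings are gathered across devices):

```python
def effective_batch_size(batch_per_gpu, num_gpus, grad_accum_steps=1):
    # Number of examples contributing to each optimizer step under
    # data-parallel training with gradient accumulation.
    return batch_per_gpu * num_gpus * grad_accum_steps
```

With batch_size_per_gpu = 2, even several GPUs yield only a handful of examples per step, which is small for contrastive learning.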

STS Evaluation

  • The STS Benchmark (validation set) can be used for evaluation during training
  • This benchmark indicates how well the model creates an embedding vector for a given sentence
  • Scores are calculated using Spearman correlation
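Spearman correlation is just Pearson correlation on rank-transformed scores; a minimal tie-free sketch (in practice `scipy.stats.spearmanr`, which also averages tied ranks, is the usual choice):

```python
import numpy as np

def spearman(x, y):
    # Spearman = Pearson correlation of the rank variables.
    # Note: this simple ranking does not average tied ranks.
    def rank(v):
        order = np.argsort(v)
        r = np.empty(len(v), dtype=float)
        r[order] = np.arange(len(v))
        return r
    rx, ry = rank(np.asarray(x)), rank(np.asarray(y))
    rx -= rx.mean()
    ry -= ry.mean()
    return float((rx * ry).sum() / np.sqrt((rx ** 2).sum() * (ry ** 2).sum()))
```

The model's cosine similarity for each STS sentence pair is ranked against the human similarity score; a correlation of 1.0 means the two orderings agree perfectly.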
