ConversationSummarizerLLM

Fine-tuning Pegasus and FLAN-T5 pre-trained language models on the DialogSum dataset for conversation summarization, to optimize the context window in RAG-LLMs.

License: MIT

Topics: dialogue-summarization, fine-tuning, flan-t5, lora, peft, peft-fine-tuning-llm, pegasus


Clone Repo

git clone https://github.com/shoryasethia/ConversationSummarizerLLM

AIM: Conversation Summarization for Context Window Optimization in RAG-LLMs

Models like FLAN-T5 and Pegasus, when fine-tuned on dialogue datasets such as DialogSum, become specialized at summarizing conversations. This is highly useful wherever managing the context window size is crucial, such as in Retrieval-Augmented Generation (RAG) pipelines and other language models with limited context windows.

This repo is a simpler implementation of the ideas/methods mentioned in this OpenAI blog post.

Context Window Optimization with Summarization

By summarizing conversations, one can effectively reduce the amount of text that needs to be fed into the language model's context window. This not only helps in managing the limited context size but also ensures that the most relevant information is retained. Here’s how this approach is beneficial and can be implemented:

  • Context Window Constraints: Most large language models, including GPT, Gemini, and BERT, have a maximum context window of a few thousand tokens. Summarizing previous conversations lets you fit more relevant content within this limit.

  • Enhanced Focus: Summarization distills the key points of a conversation, enabling the model to focus on the most important information without being overwhelmed by less relevant details.

  • Memory Efficiency: By reducing the amount of text, we make more efficient use of the model's memory, which can lead to faster processing and reduced computational load. The sketch after this list illustrates the token savings.
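
A minimal sketch of the token savings, assuming the FLAN-T5 tokenizer; the dialogue and summary strings are illustrative placeholders, not DialogSum examples:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# Illustrative placeholders; in practice these come from the chat history
dialogue = (
    "#Person1#: Hi, I ordered a laptop last week but it still hasn't shipped. "
    "#Person2#: I'm sorry about that. Could you give me the order number? "
    "#Person1#: Sure, it's 48213. "
    "#Person2#: Thanks. The item was back-ordered; it ships tomorrow."
)
summary = "#Person1# asks about a delayed laptop order and #Person2# confirms it ships tomorrow."

# Compare how many context-window tokens each version consumes
print(len(tokenizer(dialogue)["input_ids"]))  # full conversation
print(len(tokenizer(summary)["input_ids"]))   # summary: far fewer tokens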

Implementation Strategy

  • Summarization Model: Fine-tune a summarization model such as FLAN-T5 or Pegasus on a dialogue dataset such as DialogSum.

Integrate Summarization with RAG:

  • Summarize Historical Conversations: Before feeding past conversations into the context window, use the summarization model to generate concise summaries.
  • Feed Summarized Conversations: Integrate these summaries into the context window along with the most recent interaction to maintain continuity, as in the sketch below.
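
A minimal sketch of this flow, assuming the fine-tuned summarizer is loaded with the transformers pipeline API; the checkpoint path and the prompt template are hypothetical placeholders:

from transformers import pipeline

# Hypothetical path to a fine-tuned summarization checkpoint
summarizer = pipeline("summarization", model="./flan-t5-peft-lora-dialogsum")

def build_context(history: list[str], latest_message: str) -> str:
    """Summarize past turns and combine the summary with the newest message."""
    past = "\n".join(history)
    summary = summarizer(past, max_length=128)[0]["summary_text"]
    # The compact summary plus the latest turn go into the model's context
    return (f"Conversation so far (summarized): {summary}\n"
            f"Latest message: {latest_message}")

context = build_context(
    ["User: My package is late.", "Agent: Let me check the tracking number."],
    "User: Any update on my package?",
)
print(context)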

This method can be particularly beneficial for applications involving long-term user interactions, such as customer support, personal assistants, and other conversational AI systems. By focusing on summarization, you can effectively manage and enhance the context provided to the model, leading to more coherent and contextually appropriate responses.

For this, I fine-tuned the google/pegasus-xsum and google/flan-t5-base pre-trained LMs on the knkarthick/dialogsum dataset.

Load Dataset

from datasets import load_dataset

huggingface_dataset_name = "knkarthick/dialogsum"

dataset = load_dataset(huggingface_dataset_name)

or run

git clone https://huggingface.co/datasets/knkarthick/dialogsum
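
Each split of the loaded dataset exposes id, dialogue, summary, and topic fields; a quick inspection, assuming the load_dataset path above:

print(dataset)                   # train/validation/test splits
sample = dataset["train"][0]
print(sample["dialogue"][:200])  # the raw conversation
print(sample["summary"])         # the reference summary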

Load Pre-Trained Model

import torch
from transformers import AutoModelForSeq2SeqLM

# Pegasus
model_name = 'google/pegasus-xsum'
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# or FLAN-T5
model_name = 'google/flan-t5-base'
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
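
The matching tokenizer is needed for both inference and fine-tuning. A minimal zero-shot baseline with the original model follows; the prompt wording is just one reasonable choice, not the exact prompt from the notebooks:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

dialogue = dataset["test"][0]["dialogue"]
prompt = f"Summarize the following conversation.\n\n{dialogue}\n\nSummary:"

inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
output_ids = original_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))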

PEFT (Parameter-Efficient Fine-Tuning) - LoRA

The LoRA configuration I used for fine-tuning:

To find the LoRA target modules, print the model architecture:

print(original_model)

Output for google/pegasus-xsum

PegasusForConditionalGeneration(
  (model): PegasusModel(
    (shared): Embedding(96103, 1024, padding_idx=0)
    (encoder): PegasusEncoder(
      (embed_tokens): Embedding(96103, 1024, padding_idx=0)
      (embed_positions): PegasusSinusoidalPositionalEmbedding(512, 1024)
      (layers): ModuleList(
        (0-15): 16 x PegasusEncoderLayer(
          (self_attn): PegasusAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
      )
      (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): PegasusDecoder(
      (embed_tokens): Embedding(96103, 1024, padding_idx=0)
      (embed_positions): PegasusSinusoidalPositionalEmbedding(512, 1024)
      (layers): ModuleList(
        (0-15): 16 x PegasusDecoderLayer(
          (self_attn): PegasusAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): PegasusAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
      )
      (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=1024, out_features=96103, bias=False)
)

I applied LoRA only to q_proj and v_proj; the parameters of k_proj, the other dense layers, and all remaining weights are kept frozen.

from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=32,                               # rank of the LoRA update matrices
    lora_alpha=32,
    target_modules=['q_proj',
                    'v_proj',],
    bias="none",
    lora_dropout=0.05,                  # conventional value
    task_type=TaskType.SEQ_2_SEQ_LM     # both Pegasus and FLAN-T5 are seq2seq
)
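
Wrapping the base model with this config and training it might look like the sketch below; the prompt format, hyperparameters, and output path are illustrative, not the exact values from the notebooks:

from peft import get_peft_model
from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments

peft_model = get_peft_model(original_model, lora_config)
peft_model.print_trainable_parameters()  # only the small LoRA adapters train

def tokenize_fn(batch):
    # Illustrative instruction format; any consistent template works
    inputs = ["Summarize the following conversation.\n\n" + d + "\n\nSummary:"
              for d in batch["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    labels = tokenizer(text_target=batch["summary"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized = dataset.map(tokenize_fn, batched=True,
                        remove_columns=dataset["train"].column_names)

training_args = TrainingArguments(
    output_dir="./peft-dialogsum-checkpoints",  # illustrative path
    learning_rate=1e-3,
    num_train_epochs=5,
    per_device_train_batch_size=8,
    logging_steps=50,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=peft_model),  # pads inputs, masks label padding
)
trainer.train()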

Results

| Model | ROUGE-1 | BLEU | Notebook | FT Checkpoints |
| --- | --- | --- | --- | --- |
| pegasus-xsum | 0.1960 | 0.3960 | - | - |
| pegasus-peft-lora-dialogsum | 0.2510 | 0.4371 | .ipynb, Google Colab | 🔗 |
| flan-t5-base | 0.3765 | 0.1550 | - | - |
| flan-t5-peft-lora-dialogsum | 0.4293 | 0.1906 | .ipynb, Google Colab | 🔗 |
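
A sketch of how ROUGE and BLEU scores like these can be computed with the Hugging Face evaluate library; the prediction and reference lists are placeholders for the generated summaries and the DialogSum reference summaries:

import evaluate

rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Placeholders for model outputs and reference summaries
predictions = ["#Person1# asks about a delayed order and #Person2# confirms shipping."]
references = ["#Person1# asks about a late order; #Person2# says it ships tomorrow."]

print(rouge.compute(predictions=predictions, references=references))  # rouge1/2/L/Lsum
print(bleu.compute(predictions=predictions, references=references))   # bleu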

Fine-tuning Pegasus: absolute percentage improvement of the PEFT model over the original model

  • rouge1: 28.51%
  • rouge2: 1.36%
  • rougeL: 2.70%
  • rougeLsum: 9.64%

Fine-tuning FLAN-T5: absolute percentage improvement of the PEFT model over the original model

  • rouge1: 4.73%
  • rouge2: 1.14%
  • rougeL: 1.63%
  • rougeLsum: 2.58%

Reference/proof of these values: https://github.com/shoryasethia/ConversationSummarizerLLM/tree/main/output

If this repo helped you, consider giving it a star.
