
This project is forked from shekswess/llm-medical-finetuning.


A code repository that contains all the code for fine-tuning some of the popular LLMs on medical data


LLM-Medical-Finetuning

This repository contains all the code necessary to fine-tune (PEFT using LoRA/QLoRA) the most popular 7B/8B-parameter instruct/chat LLMs (Mistral, Llama 2, Llama 3, Gemma) specifically on medical data, utilizing the unsloth library. The repository consists of two parts:

  • preparing the instruct medical datasets
  • finetuning the instruct LLMs on the prepared datasets

Preparing the datasets

For this showcase project, two datasets are used:

Medical Meadow Wikidoc

The Medical Meadow Wikidoc dataset comprises question-answer pairs sourced from WikiDoc, an online platform where medical professionals collaboratively contribute and share contemporary medical knowledge. WikiDoc features two primary sections: the "Living Textbook" and "Patient Information". The "Living Textbook" encompasses chapters across various medical specialties, from which the content was extracted. Using GPT-3.5-Turbo, the paragraph headings were transformed into questions, with the respective paragraphs used as answers. Notably, the structure of "Patient Information" is distinct: each section's subheading already serves as a question, eliminating the need for rephrasing.

MedQuAD

MedQuAD is a comprehensive collection of 47,457 medical question-answer pairs compiled from 12 authoritative sources within the National Institutes of Health (NIH), including domains like cancer.gov, niddk.nih.gov, GARD, and MedlinePlus Health Topics. These question-answer pairs span 37 distinct question types, covering a wide spectrum of medical subjects, including diseases, drugs, and medical procedures. The dataset features additional annotations provided in XML files, facilitating various Information Retrieval (IR) and Natural Language Processing (NLP) tasks. These annotations encompass crucial information such as question type, question focus, synonyms, the Concept Unique Identifier (CUI) from the Unified Medical Language System (UMLS), and Semantic Type. Moreover, the dataset categorizes question focuses into three main categories (Disease, Drug, or Other), with the exception of the collections from MedlinePlus, which exclusively focus on diseases.

For our experiments, 12 different versions of the datasets were created; they are available as Hugging Face datasets.
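As a rough illustration of what the scripts in src/data_processing do, here is a minimal sketch that wraps raw question-answer pairs in an instruct-style template. The template string and the column names are assumptions for illustration, not the repository's actual code:

```python
# A minimal sketch of the idea behind src/data_processing (not the actual
# scripts): wrap raw question-answer pairs in an instruct-style template.
# The template and the column names "question"/"answer" are assumptions;
# each model (Gemma, Llama 2, Llama 3, Mistral) needs its own chat format.
import pandas as pd
from datasets import Dataset

TEMPLATE = "<s>[INST] {question} [/INST] {answer}</s>"  # Llama-2-style format

def to_instruct_format(row: dict) -> dict:
    # Produce the single text field the trainer will consume.
    return {"text": TEMPLATE.format(question=row["question"], answer=row["answer"])}

raw = pd.read_csv("data/raw_data/medquad.csv")
dataset = Dataset.from_pandas(raw).map(to_instruct_format)
dataset.save_to_disk("data/processed_datasets/medical_llama2_instruct_dataset")
```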

Finetuning the LLMs

The fine-tuning of the LLMs is based on PEFT (Parameter-Efficient Fine-Tuning), here in the form of supervised tuning using LoRA/QLoRA. Because the resources on Google Colab are limited (a single T4 GPU), sparing resources is crucial; that is why 4-bit quantized models are used, available on Hugging Face through unsloth (https://github.com/unslothai/unsloth). Most of the code is also based on the unsloth library. For the fine-tuning, the following models are used (a minimal setup sketch follows the list):

  • gemma-1.1-7b-it-bnb-4bit
  • llama-2-7b-chat-bnb-4bit
  • llama-3-8b-Instruct-bnb-4bit
  • mistral-7b-instruct-v0.2-bnb-4bit
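Below is a minimal sketch of loading one of these 4-bit checkpoints and attaching LoRA adapters with unsloth; the hyperparameters (rank, alpha, target modules) are illustrative assumptions, not necessarily the values used in the notebooks.

```python
# A minimal sketch of loading a 4-bit checkpoint and attaching LoRA adapters
# with unsloth. Hyperparameters here are illustrative assumptions.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-Instruct-bnb-4bit",  # one of the listed models
    max_seq_length=2048,
    dtype=None,          # auto-detect (float16 on a T4)
    load_in_4bit=True,   # 4-bit quantization to fit in T4 memory
)

# QLoRA: train low-rank adapters on top of the frozen 4-bit base weights.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                # LoRA rank (assumed value)
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing=True,  # trades compute for memory on the T4
)
```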

Many more details about the fine-tuning process can be found in the notebooks in the src/finetuning_notebooks folder.
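For the training step itself, a rough sketch using trl's SFTTrainer is shown below. All values are illustrative, and the exact arguments depend on the trl version in use (newer releases move several of these into SFTConfig).

```python
# Sketch of the supervised fine-tuning step with trl's SFTTrainer.
# Values are illustrative assumptions, not the notebooks' exact settings.
from transformers import TrainingArguments
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,        # processed instruct dataset from earlier
    dataset_text_field="text",    # column holding the formatted prompt
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,   # effective batch size of 8
        max_steps=250,                   # assumed; see the notebooks
        learning_rate=2e-4,
        fp16=True,                       # the T4 does not support bf16
        logging_steps=1,                 # per-step loss, as in the CSVs
        output_dir="outputs",
    ),
)
trainer_stats = trainer.train()  # stats like those in artifacts/trainer_stats_*.json
```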

Models trained using this codebase are available on Hugging Face.

Training Loss on all models

(Figure: training loss curves for all four models; see artifacts/all_models.png.)

DISCLAIMER: The models are trained on a small dataset (only 2000 entries).
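The loss curves can be reproduced from the per-step CSVs in the artifacts folder. A minimal sketch, assuming the first column holds the step and the second the loss (inspect the CSVs to confirm their actual headers):

```python
# Sketch of regenerating the loss plot from the per-step CSVs in artifacts/.
# Assumes column 0 is the step and column 1 the loss value (an assumption).
import matplotlib.pyplot as plt
import pandas as pd

for name in ["gemma", "llama2", "llama3", "mistral"]:
    df = pd.read_csv(f"artifacts/{name}_loss.csv")
    plt.plot(df.iloc[:, 0], df.iloc[:, 1], label=name)

plt.xlabel("Step")
plt.ylabel("Training loss")
plt.legend()
plt.savefig("all_models_reproduced.png")
```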

Repository structure

.
├── .vscode                                                 # VSCode settings
│   └── settings.json                                       # Settings for the formatting of the code
├── artifacts                                               # Artifacts generated during the training of the models
│   ├── all_models.png                                      # Training loss of all models
│   ├── gemma_loss.csv                                      # Training loss of the Gemma model per step
│   ├── gemma_loss.png                                      # Training loss of the Gemma model
│   ├── llama2_loss.csv                                     # Training loss of the Llama 2 model per step
│   ├── llama2_loss.png                                     # Training loss of the Llama 2 model
│   ├── llama3_loss.csv                                     # Training loss of the Llama 3 model per step
│   ├── llama3_loss.png                                     # Training loss of the Llama 3 model
│   ├── mistral_loss.csv                                    # Training loss of the Mistral model per step
│   ├── mistral_loss.png                                    # Training loss of the Mistral model
│   ├── trainer_stats_gemma.json                            # Trainer stats of the Gemma model
│   ├── trainer_stats_llama2.json                           # Trainer stats of the Llama 2 model
│   ├── trainer_stats_llama3.json                           # Trainer stats of the Llama 3 model
│   └── trainer_stats_mistral.json                          # Trainer stats of the Mistral model
├── data                                                    # Datasets used in the project
│   ├── processed_datasets                                  # Processed datasets
│   │   ├── medical_gemma_instruct_dataset                  # Processed dataset for the Gemma model
│   │   ├── medical_gemma_instruct_dataset_short            # Smaller processed dataset for the Gemma model
│   │   ├── medical_llama2_instruct_dataset                 # Processed dataset for the Llama 2 model
│   │   ├── medical_llama2_instruct_dataset_short           # Smaller processed dataset for the Llama 2 model
│   │   ├── medical_llama3_instruct_dataset                 # Processed dataset for the Llama 3 model
│   │   ├── medical_llama3_instruct_dataset_short           # Smaller processed dataset for the Llama 3 model
│   │   ├── medical_mistral_instruct_dataset                # Processed dataset for the Mistral model
│   │   └── medical_mistral_instruct_dataset_short          # Smaller processed dataset for the Mistral model
│   └── raw_data                                            # Raw datasets
│       ├── medical_meadow_wikidoc.csv                      # Medical Meadow Wikidoc dataset
│       └── medquad.csv                                     # MedQuAD dataset
├── src                                                     # Source code
│   ├── data_processing                                     # Data processing scripts
│   │   ├── create_process_datasets.py                      # Script to create processed datasets
│   │   ├── instruct_datasets.py                            # Defining the processing of the datasets to be in the instruct format
│   │   └── requirements.txt                                # Requirements for the data processing scripts
│   └── finetuning_notebooks                                # Notebooks for the fine-tuning of the LLMs
│       ├── gemma_1_1_7b_it_medical.ipynb                   # Notebook for the fine-tuning of the Gemma LLM
│       ├── llama_2_7b_chat_medical.ipynb                   # Notebook for the fine-tuning of the Llama2 LLM
│       ├── llama_3_8b_instruct_medical.ipynb               # Notebook for the fine-tuning of the Llama3 LLM
│       └── mistral_7b_instruct_v02_medical.ipynb           # Notebook for the fine-tuning of the Mistral LLM
├── .gitignore                                              # Git ignore file
└── README.md                                               # README file (this file)

