
MEDS-torch


Description

This repository provides a comprehensive suite for advanced machine learning over Electronic Health Records (EHR) using PyTorch, PyTorch Lightning, and Hydra for configuration management. The project leverages MEDS_Polars, a robust system for transforming EHR data into a structured, tabular format that enhances the accessibility and analyzability of medical datasets. By employing a variety of tokenization strategies and neural network architectures, this framework facilitates the development and testing of models that can predict, generate, and understand complex medical trajectories.

Key features include:

  • Configurable ML Pipeline: Utilize Hydra to dynamically adjust configurations and seamlessly integrate with PyTorch Lightning for scalable training across multiple environments.
  • Advanced Tokenization Techniques: Explore different approaches to processing EHR data, such as triplet tokenization and code-specific embeddings, to capture the nuances of medical information.
  • Pre-training Strategies: Leverage contrastive learning, autoregressive token forecasting, and other pre-training techniques to boost model performance with MEDS data.
  • Transfer Learning: Implement and test transfer learning scenarios to adapt pre-trained models to new tasks or datasets effectively.
  • Generative and Supervised Models: Support for zero-shot generative models and supervised training allows for a broad application of the framework in predictive and generative tasks within healthcare.

The goal of this project is to push the boundaries of what's possible in healthcare machine learning by providing a flexible, robust, and scalable platform that accommodates a wide range of research and operational needs. Whether you're conducting academic research, developing clinical applications, or exploring new machine learning methodologies, this repository offers the tools and flexibility needed to innovate and excel in the field of medical data analysis.

Installation

Pip

# clone project
git clone git@github.com:Oufattole/meds-torch.git
cd meds-torch

# [OPTIONAL] create conda environment
conda create -n meds-torch python=3.12
conda activate meds-torch

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -e .

How to run

Train model with default configuration

# train on CPU
python -m meds_torch.train trainer=cpu

# train on GPU
python -m meds_torch.train trainer=gpu

Train model with chosen experiment configuration from configs/experiment/

python -m meds_torch.train experiment=experiment_name.yaml

You can override any parameter from the command line like this:

python -m meds_torch.train trainer.max_epochs=20 data.batch_size=64

📌  Introduction

Why you might want to use it:

✅ Save on boilerplate
Easily add new models, datasets, tasks, experiments, and train on different accelerators, like multi-GPU, TPU or SLURM clusters.

✅ Support different tokenization methods for EHR data

  • Triplet Tokenization (TODO: add explanations of each subtype to the docs)
  • Everything is text (TODO: add explanations of each subtype to the docs)
  • Everything is a code (TODO: add explanations of each subtype to the docs)

✅ MEDS data pretraining (and Transfer Learning Support)

  • General Contrastive window Pretraining
  • STraTS Value Forecasting
  • Autoregressive Token Forecasting
  • Token Masked Imputation

✅ Zero shot Generative Model Support

  • Supports generating MEDS-format future trajectories for patients using the Autoregressive Token Forecasting model.

✅ Supervised Model Support

  • Randomly initialize a model and train it in a supervised manner on your MEDS-format medical data.
  • Load pretrained model weights

✅ Education
Thoroughly commented. You can use this repo as a learning resource.

✅ Reusability
Collection of useful MLOps tools, configs, and code snippets. You can use this repo as a reference for various utilities.

Why you might not want to use it:

❌ Things break from time to time
Lightning and Hydra are still evolving and integrate many libraries, which means sometimes things break. For the list of currently known problems visit this page.

❌ Not adjusted for data engineering
The template is not really adjusted for building data pipelines that depend on each other. It's more efficient to use it for model prototyping on ready-to-use data.

❌ Overfitted to simple use case
The configuration setup is built with simple Lightning training in mind. You might need to put in some effort to adjust it for different use cases, e.g. Lightning Fabric.

❌ Might not support your workflow
For example, you can't resume hydra-based multirun or hyperparameter search.

Loggers

The wandb logger is installed with the repo by default. If you wish to use a different logger, install it from the list below:

# neptune-client
# mlflow
# comet-ml
# aim>=3.16.2  # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550

Development Help

pytest-instafail shows failures and errors instantly instead of waiting until the end of the test session. Run it with:

pytest --instafail

To run failing tests continuously each time you edit code until they pass:

pytest --looponfail

To run tests on 8 parallel workers, run:

pytest -n 8


meds-torch's Issues

Transfer Learning Methods

We can add additional classes that inherit from the supervised model class. For now the implementation will be Triplet-tokenization specific, but we should generalize it in the future. These are the methods:

  • #18
  • Masked Imputation
  • Value Forecasting (in a predetermined window)
  • Autoregressive Forecasting

Everything is a Token

As opposed to triplet embeddings, we should try an "everything is a token" approach used in past works: CEHR-BERT and ETHOS.

For example, imagine a patient has a time series of two observations: a potassium lab in quantile 9/10, and one day later a creatinine lab in quantile 2/10.

  • We could define this as three tokens:
  1. quantile 9/10 potassium lab
  2. 1-day time gap
  3. quantile 2/10 creatinine lab
  • We could also define this as five tokens:
  1. potassium lab
  2. quantile 9/10
  3. 1-day time gap
  4. creatinine lab
  5. quantile 2/10

Let's support both!

There's a nice figure of this in the ETHOS paper.
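To make the two layouts concrete, here is a minimal sketch in Python of the potassium/creatinine example rendered both ways; the vocabulary entries and token names are hypothetical, not the repo's actual vocabulary.

# Minimal sketch (hypothetical vocabulary) of the two "everything is a token"
# layouts for the potassium/creatinine example above.
three_token_vocab = {
    "K_LAB_Q9": 0,     # quantile 9/10 potassium lab (code and quantile fused)
    "TIME_GAP_1D": 1,  # 1-day time gap
    "CR_LAB_Q2": 2,    # quantile 2/10 creatinine lab
}
five_token_vocab = {
    "K_LAB": 0,        # potassium lab code
    "Q9": 1,           # quantile 9/10
    "TIME_GAP_1D": 2,  # 1-day time gap
    "CR_LAB": 3,       # creatinine lab code
    "Q2": 4,           # quantile 2/10
}

# The same patient history rendered in both layouts:
three_token_seq = [three_token_vocab[t] for t in ["K_LAB_Q9", "TIME_GAP_1D", "CR_LAB_Q2"]]
five_token_seq = [five_token_vocab[t] for t in ["K_LAB", "Q9", "TIME_GAP_1D", "CR_LAB", "Q2"]]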

Support Multi-task learning

We currently support single-task supervised training; we should also support multiple tasks. Users would input a list of task parquet files (with patient_id, end-point timestamp, and labels), and we should join these dataframes and then integrate them into the pytorch_dataset class here (a join sketch is below).
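A rough sketch of the join step, assuming pandas and hypothetical column names (patient_id, timestamp, label) and file paths; the real pytorch_dataset integration would consume the resulting dataframe.

import functools

import pandas as pd

# Hypothetical sketch: each task file holds (patient_id, timestamp, label).
task_files = {"mortality": "tasks/mortality.parquet", "los": "tasks/los.parquet"}

frames = [
    pd.read_parquet(path).rename(columns={"label": task_name})
    for task_name, path in task_files.items()
]
# Full outer join on patient and end-point timestamp so every labeled window is kept.
joined = functools.reduce(
    lambda left, right: left.merge(right, on=["patient_id", "timestamp"], how="outer"),
    frames,
)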

Generate Embeddings

We should provide support for users to get timestamped embeddings given a pretrained model.

all_text: Convert the entire patient history into text and use a language model to get an embedding

Pipeline:

  1. Convert the history to one big string (use the same approach as for observation_text and then concatenate all of those strings).
  2. Load a tokenizer (could be from a pretrained Hugging Face model) and tokenize into integers.
  3. Allow loading of pretrained language models and generation of a representation for downstream tasks (e.g. Mamba: mamba-130m-hf, masked imputation model: BERT, autoregressive transformer: microsoft/phi-1_5).
    We should add some caching support later, maybe just using safetensors with a dictionary from event ID to the tensor. A sketch of steps 2-3 is below.
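A minimal sketch of steps 2-3 using the Hugging Face transformers API, with bert-base-uncased standing in for whichever pretrained model is chosen and mean pooling as a placeholder aggregation:

import torch
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = "bert-base-uncased"  # placeholder; any pretrained LM could be used
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# Step 1's output: the whole history rendered as one string (toy example here).
history_text = "2021-01-01: potassium lab, quantile 9/10. 2021-01-02: creatinine lab, quantile 2/10."
inputs = tokenizer(history_text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    hidden = model(**inputs).last_hidden_state  # (1, seq_len, hidden_dim)

# Mean-pool over tokens to get one patient-history embedding (placeholder pooling).
embedding = hidden.mean(dim=1).squeeze(0)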

Autoregressive Modeling

Generative modeling -- triplet works
https://github.com/mmcdermott/EventStreamGPT/blob/main/EventStream/transformer/generation/generation_utils.py#L73

Simplify this code^
With the triplet code you have resolved one form of conditional independence.

The code helps with taking the time prediction, passing it back to the model to get the code, then doing it again to get the value prediction (a generation-step sketch follows the solutions list below).

Solutions:

  • conditionally independent -- time code value - sometimes same time sometimes different
  • nested events
    • ES-GPT style
      • performance benefit of nesting approach - reduces sequence length significantly
      • reflects conditional dependence
        • time -> code -> is value observed -> value
      • makes preprocessing less effective
    • Ethos - style
      • add in time tokens
      • only reflects
        • time -> code, value
  • Do KEY VALUE caching-- only generate for the newest embedding
  • 3 2 2 2 - triplet
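As a rough illustration (not ESGPT's or this repo's actual code), a single generation step respecting the time -> code -> value chain could look like the following; all module and head names are hypothetical:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch of one autoregressive generation step that follows the
# conditional chain time -> code -> value. `last_hidden` is the newest
# position's embedding from any causal sequence model, which is all that is
# needed when key/value caching is used upstream.
class TripletGenerationStep(nn.Module):
    def __init__(self, hidden_dim: int, n_codes: int):
        super().__init__()
        self.time_head = nn.Linear(hidden_dim, 1)
        self.code_head = nn.Linear(hidden_dim + 1, n_codes)
        self.value_head = nn.Linear(hidden_dim + 1 + n_codes, 1)

    def forward(self, last_hidden: torch.Tensor):
        time_delta = self.time_head(last_hidden)                      # (B, 1)
        code_logits = self.code_head(torch.cat([last_hidden, time_delta], dim=-1))
        code = code_logits.argmax(dim=-1)                             # (B,)
        code_onehot = F.one_hot(code, code_logits.shape[-1]).float()  # (B, n_codes)
        # Value prediction is conditioned on both the predicted time and code.
        value = self.value_head(
            torch.cat([last_hidden, time_delta, code_onehot], dim=-1)
        )
        return time_delta, code, value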

Overall Config Setup

Support for Modular Configuration with Early and Late Fusion Options

Problem

Current model configurations lack the flexibility to easily incorporate and experiment with early and late fusion techniques, which are crucial for enhancing model performance by integrating information at different stages of the processing pipeline.

Proposed Solution

Develop a modular configuration that supports user defined models and losses:

  • Data Processing: Similar initial data processing step.
  • Input Encoder: Encode data as in the standard pathway.
  • Model: A more flexible stage where the model can implement any form of fusion (early or late) and handle data labeling internally according to specific experimental needs. The inputs to this stage are a sequence model architecture, for example an LSTM or a transformer decoder, and the output of the input encoder. If there is a specific pretraining task (such as performing early fusion and shuffling windows for OCP), it is performed at this stage. Pretraining and finetuning models will be separate files, and weight loading should be supported between them (see the sketch after this list).
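A hedged sketch of what the Model stage's interface could look like, assuming the backbone maps (batch, time, hidden) to (batch, time, hidden) and that windows arrive as a list of encoded tensors; names and shapes are illustrative only:

import torch
import torch.nn as nn

# Hypothetical Model stage: built from an arbitrary sequence backbone and fed
# the input encoder's output. Assumes backbone maps (B, T, H) -> (B, T, H).
class FusionModel(nn.Module):
    def __init__(self, backbone: nn.Module, hidden_dim: int, fusion: str = "late"):
        super().__init__()
        self.backbone = backbone
        self.fusion = fusion
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, encoded_windows: list[torch.Tensor]) -> torch.Tensor:
        if self.fusion == "early":
            # Early fusion: concatenate windows along time before the backbone.
            pooled = self.backbone(torch.cat(encoded_windows, dim=1)).mean(dim=1)
        else:
            # Late fusion: encode each window separately, then average.
            pooled = torch.stack(
                [self.backbone(w).mean(dim=1) for w in encoded_windows]
            ).mean(dim=0)
        return self.head(pooled)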

EventStream Model Support

We should add support for EventStream models. The tokenization is already supported in the pytorch dataset class; just set the collate type to event_stream in your hydra config. I think we just need to copy some code from the ESGPT GitHub to run this.

Exploring different temporal encoding Strategies

For triplet encoding we use a one-to-many feed-forward network for the time delta and sum this with the code and value tokens to get a triplet token. We can try other temporal position methods.

  • We can use absolute time (time after the first observation or after birth, i.e. age) or relative time (time from the previous observation) -- the current implementation.
  • We can use a one-to-many feed-forward network -- the current implementation -- or we can try positional encodings using sines and cosines (maybe from here). A sketch of both is below.
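A small sketch of both options, with illustrative dimensions; neither is taken from the repo's actual implementation:

import math

import torch
import torch.nn as nn

class FeedForwardTimeEmbedding(nn.Module):
    """One-to-many feed-forward embedding of a scalar time delta (the approach described above)."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(1, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, time_delta: torch.Tensor) -> torch.Tensor:
        return self.proj(time_delta.unsqueeze(-1))  # (..., dim)


def sinusoidal_time_encoding(times: torch.Tensor, dim: int) -> torch.Tensor:
    """Sine/cosine encoding of (absolute or relative) times, transformer-style."""
    freqs = torch.exp(-math.log(10000.0) * torch.arange(0, dim, 2) / dim)
    angles = times.unsqueeze(-1) * freqs                      # (..., dim // 2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (..., dim)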

Embedding tokens

I'm trying to implement this observation-level embedder (as described in the screenshot attached to this issue) to convert the pytorch dataset batches into a sequence of embeddings, one for each observation, with static variables preceding dynamic variables.

Where do I find the vocab size in the pytorch dataset class?

Contrastive Methods

We want to support general contrastive learning window definitions.

There are some common patient specific latent space temporal structures we may desire (as shown in Figure A).

  • Consistent - a patient has a constant representation if we look at two different windows of time for them. This can be local (only for adjacent windows) or global (for any two randomly selected non-overlapping windows).
  • Continuity/interpolation - if we take two non-adjacent windows for the patient, the window between them should be approximately an average of those two windows.
  • Ordering - if we reverse time, or shuffle windows and input those to a model, the representation will be very different. This can also be local (adjacent windows) or global (no adjacency constraint).

We can also have multi-patient properties:

  • Label-Based Neighbors - patients with the same label should have similar representations.

Additionally, we want to allow users to select global and local windows.

(Figure A: Subtype Representation Properties)
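A hypothetical sketch of local vs. global window sampling from one patient's ordered event list (the helper and its arguments are not part of the current codebase):

import random

def sample_window_pair(events: list, window_len: int, local: bool = True):
    """Return two non-overlapping windows; adjacent if `local`, randomly placed otherwise."""
    max_start = len(events) - 2 * window_len
    if max_start < 0:
        raise ValueError("patient history too short for two windows")
    start_a = random.randint(0, max_start)
    if local:
        start_b = start_a + window_len  # adjacent window
    else:
        start_b = random.randint(start_a + window_len, len(events) - window_len)
    window_a = events[start_a : start_a + window_len]
    window_b = events[start_b : start_b + window_len]
    return window_a, window_b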

Adding dummy data

Currently we only have 4 training examples, 1 val example, and 0 test examples. Let's make a larger MEDS dataset, run meds-polars on it, and copy the output into the tests/test_data folder. Ideally we should have around 128 patients with multiple tasks defined (mortality and length of stay).

We can just duplicate rows from this branch of meds_polars for now: https://github.com/mmcdermott/MEDS_polars_functions/blob/dfd0fa7fe9d844121d0a9dce5f71d68e6340b9df/tests/test_preprocessing.py#L131C19-L131C27

Implementation of a Masking Stage with Random Masking Options

Problem

The absence of a dedicated masking stage in our pipeline limits our ability to handle incomplete or noisy data effectively during model training.

Proposed Solution

Introduce a masking stage designed to randomly mask a specified percentage of the data or subsequences within the data:

  • Position: Place the masking stage after the input encoder and before the sequence model.
  • Functionality:
    • Support random masking: either a random percentage of the tokens is masked, or a randomly sampled contiguous subsequence is masked (see the sketch after this list).
    • We should add to the batch a key indicating the labels that will be used by the Model stage to compute masked imputation loss.
  • Configurability: Allow users to set the percentage of data to mask.
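A minimal sketch of such a masking stage, assuming token-ID batches of shape (batch, seq_len); the function name, default ratio, and mask token ID are placeholders:

import torch

# Hypothetical sketch of the proposed masking stage: mask either a random
# fraction of tokens or one random contiguous subsequence per sequence, and
# return a boolean mask the Model stage could use as imputation labels.
def apply_random_mask(tokens: torch.Tensor, mask_ratio: float = 0.15,
                      contiguous: bool = False, mask_token_id: int = 0):
    batch_size, seq_len = tokens.shape
    if contiguous:
        span = max(1, int(seq_len * mask_ratio))
        starts = torch.randint(0, seq_len - span + 1, (batch_size,))
        mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
        for i, s in enumerate(starts):
            mask[i, s : s + span] = True
    else:
        mask = torch.rand(batch_size, seq_len) < mask_ratio
    masked = tokens.masked_fill(mask, mask_token_id)
    return masked, mask  # `mask` marks the positions whose originals are the labels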

Implement Multi-Patient Contrastive Learning

Problem

Lack of support for contrastive learning across multiple patients, which is required by some contrastive learning pretraining tasks. NCL is the one I have in mind. These pretraining methods often use the supervised label or dynamic time warping within a batch, so maybe we can add a stage for batch mining.

Proposed Solution

  • Patient ID in the batch: Include patient_id in the window data struct to allow for tracking the patient and mining positive and negative pairs
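A small sketch of in-batch pair mining once patient_id is available; the function is hypothetical:

import torch

# Hypothetical sketch of in-batch mining once `patient_id` is in the window
# data struct: windows from the same patient are positives, others negatives.
def mine_pairs(patient_ids: torch.Tensor):
    """patient_ids: (batch,) long tensor; returns boolean (batch, batch) masks."""
    same_patient = patient_ids.unsqueeze(0) == patient_ids.unsqueeze(1)
    positives = same_patient & ~torch.eye(len(patient_ids), dtype=torch.bool)
    negatives = ~same_patient
    return positives, negatives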

Everything is Text Tokenization

We can convert timestamp, code name, and value into text description, and then use BERT or some LM to embed this into a vector. This can allow the method to generalize to any EHR (fingers crossed).

So we are going with three versions of this:

  • code_text: the code text is fed to a language model and converted to a token, which we then sum with the time and numerical-value vectors.
  • observation_text: Convert the triplet (time, code, numerical_value) into text, which we use a language model to get a vector for (see the sketch after this list).
  • #17
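A sketch of the observation_text conversion for a single triplet; the exact text format is an assumption:

from datetime import datetime

# Hypothetical sketch of the observation_text conversion: render one
# (time, code, numerical_value) triplet as a sentence a language model can embed.
def observation_to_text(time: datetime, code: str, numerical_value: float | None) -> str:
    text = f"On {time:%Y-%m-%d %H:%M}, {code}"
    if numerical_value is not None:
        text += f" with value {numerical_value:g}"
    return text + "."

# e.g. observation_to_text(datetime(2021, 1, 1, 8, 30), "potassium lab", 5.1)
# -> "On 2021-01-01 08:30, potassium lab with value 5.1."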

Enhance Patient Sampling Using PyTorch Samplers

Problem

Current patient-level sampling methods are limited and could benefit from integration with PyTorch samplers for more flexibility.

Proposed Solution

  • Integration of PyTorch Samplers: Allow the use of PyTorch's built-in samplers to facilitate more dynamic and statistically robust patient sampling methods.
  • Update Class Structures: Remove the patient_sampling manual implementation.
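A sketch of how a built-in sampler could replace the manual implementation, using a toy dataset in place of the real pytorch dataset class and uniform weights as a placeholder policy:

import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

# Toy stand-in for the real pytorch dataset class.
class ToyPatientDataset(Dataset):
    def __init__(self, n_patients: int = 128):
        self.data = torch.randn(n_patients, 16)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

dataset = ToyPatientDataset()
# One weight per patient; any policy (e.g. inverse label frequency) could go here.
weights = torch.ones(len(dataset))
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)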

Multimodal Tokenization

Currently, we only support triplet tokenization, which takes a triple (code, value, time), generates vectors for each of the three, and sums those vectors to produce a token. We should add support for:

  1. An arbitrary user defined encoder to tokenize a modality
  2. Use of frozen embeddings, i.e. a user can define pre-cached tokens (of arbitrary shape) as inputs (see the sketch below).
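A hypothetical sketch combining both extensions: per-modality encoders plus a lookup of frozen, pre-cached embeddings (e.g. loaded from safetensors); all names are illustrative:

import torch
import torch.nn as nn

class MultimodalTokenizer(nn.Module):
    def __init__(self, encoders: dict[str, nn.Module],
                 frozen_cache: dict[str, torch.Tensor] | None = None,
                 dim: int = 64):
        super().__init__()
        self.encoders = nn.ModuleDict(encoders)  # e.g. {"image": SomeCNN(), ...}
        self.frozen_cache = frozen_cache or {}   # e.g. loaded from safetensors
        self.frozen_proj = nn.LazyLinear(dim)    # map cached shapes to the token dim

    def forward(self, modality: str, payload):
        if modality in self.encoders:
            # An arbitrary user-defined encoder tokenizes this modality.
            return self.encoders[modality](payload)
        # Otherwise treat `payload` as a cache key for a frozen, pre-cached embedding.
        return self.frozen_proj(self.frozen_cache[payload])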

Noise Augmentation Stage

We can add stages that apply different kinds of noise augmentation based on the TS-TCC model.

This paper defines a pretraining task for unlabeled time-series data based on applying augmentations to the data and applying SimCLR to them. It uses two types of data augmentation, termed "strong" and "weak" augmentations, to create different views of the data for the learning process.

Strong Augmentation: Applies more intense modifications to the data, which may include drastic changes like shuffling parts of the data sequence, adding significant noise, or other alterations that substantially change the data's original structure.

Weak Augmentation: Involves less intrusive changes such as slight jittering or scaling. These augmentations retain more of the data's original characteristics compared to strong augmentations.

The pretraining task is essentially predicting future embeddings using an autoregressive model while aligning representations derived from both weakly and strongly transformed data (the SimCLR loss is used, I believe). It will require significant modifications to work on categorical data, however; we can probably add Gaussian noise along with locally shuffling the order of events or tokens (a sketch follows).
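A sketch of weak and strong augmentations adapted to embedded token sequences, per the suggestion above; the noise scales and chunk size are arbitrary placeholders:

import torch

# Hypothetical weak/strong augmentations for embedded token sequences of shape
# (batch, seq_len, dim).
def weak_augment(x: torch.Tensor, jitter_std: float = 0.01, scale_std: float = 0.1):
    # Slight jittering plus per-sequence scaling.
    scale = 1.0 + scale_std * torch.randn(x.shape[0], 1, 1)
    return x * scale + jitter_std * torch.randn_like(x)

def strong_augment(x: torch.Tensor, noise_std: float = 0.1, chunk: int = 8):
    # Locally shuffle the order of events in fixed-size chunks, then add noise.
    batch, seq_len, dim = x.shape
    perm = torch.arange(seq_len)
    for start in range(0, seq_len, chunk):
        idx = perm[start : start + chunk]
        perm[start : start + chunk] = idx[torch.randperm(len(idx))]
    return x[:, perm] + noise_std * torch.randn_like(x)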
