
Deconfounding Actor-Critic Network with Policy Adaptation for Dynamic Treatment Regimes

This repository contains the official PyTorch implementation of the following paper:

Deconfounding Actor-Critic Network with Policy Adaptation for Dynamic Treatment Regimes

Abstract: Despite intense efforts in basic and clinical research, an individualized ventilation strategy for critically ill patients remains a major challenge. Recently, dynamic treatment regime (DTR) learning with reinforcement learning (RL) on electronic health records (EHR) has attracted interest from both the healthcare industry and the machine learning research community. However, most learned DTR policies might be biased due to the existence of confounders. Although some treatment actions that non-survivors received may be helpful, if confounders cause the mortality, training RL models guided by long-term outcomes (e.g., 90-day mortality) would punish those treatment actions, causing the learned DTR policies to be suboptimal. In this study, we develop a new deconfounding actor-critic network (DAC) to learn optimal DTR policies for patients. To alleviate confounding issues, we incorporate a patient resampling module and a confounding balance module into our actor-critic framework. To avoid punishing the effective treatment actions non-survivors received, we design a short-term reward to capture patients' immediate health state changes. Combining short-term with long-term rewards could further improve the model performance. Moreover, we introduce a policy adaptation method to successfully transfer the learned model to new-source small-scale datasets. The experimental results on one semi-synthetic and two different real-world datasets show that the proposed model outperforms the state-of-the-art models. The proposed model provides individualized treatment decisions for mechanical ventilation that could improve patient outcomes.
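To make the reward design concrete, here is a minimal sketch of combining a short-term reward with a long-term (outcome) reward. It assumes, purely for illustration, that the short-term reward is the per-step decrease in a severity score such as SOFA and that the long-term reward is +1 or -1 on the last transition for survival or death. The function name `combined_rewards`, the weight `alpha`, and the reward scale are hypothetical and are not taken from this repository.

```python
from typing import List


def combined_rewards(severity: List[float], died: bool,
                     alpha: float = 0.5, terminal: float = 1.0) -> List[float]:
    """Return one reward per transition t -> t+1 (len(severity) - 1 values).

    Hypothetical design (not the repository's): the short-term reward is the
    drop in a severity score between consecutive steps, scaled by alpha; the
    long-term reward is added only to the final transition.
    """
    if len(severity) < 2:
        raise ValueError("need at least two time steps")
    # Short-term part: positive when the severity score decreases (improvement).
    rewards = [alpha * (severity[t] - severity[t + 1])
               for t in range(len(severity) - 1)]
    # Long-term part: outcome signal on the last transition only.
    rewards[-1] += -terminal if died else terminal
    return rewards
```

For a non-survivor whose score improves at the first step, `combined_rewards([10, 8, 9, 7], died=True)` returns `[1.0, -0.5, 0.0]`, so the first transition still receives a positive reward even though the episode ends in death. That is the behavior the short-term reward is meant to preserve.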

Framework

Paired survivor and non-survivor patients with similar estimated mortality risks are resampled to build balanced mini-batches. DAC adopts an actor-critic model to learn the optimal DTR policies. Patients' longitudinal data are fed into a long short-term memory network (LSTM) \cite{lstm} to generate health state sequences. To further remove the confounding bias, a dynamic inverse probability of treatment weighting method assigns weights to the rewards at each time step for each patient, and the actor network is trained with the weighted rewards.
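The dynamic inverse probability of treatment weighting step can be sketched as follows. This sketch assumes the weight at step t is the cumulative product of 1 / P(a_s | h_s) over steps s ≤ t, optionally stabilized by the marginal P(a_s), and that extreme weights are clipped. `dynamic_iptw`, `weight_rewards`, and the clipping threshold are hypothetical names and choices; the propensities would come from a separate treatment model that is not shown here.

```python
from typing import List, Optional


def dynamic_iptw(propensities: List[float],
                 marginals: Optional[List[float]] = None,
                 clip: float = 10.0) -> List[float]:
    """Cumulative (time-varying) inverse probability of treatment weights.

    propensities[t]: P(a_t | h_t) of the action actually taken at step t.
    marginals[t]:    P(a_t), the stabilizing numerator (defaults to 1.0,
                     which gives unstabilized weights).
    """
    if marginals is None:
        marginals = [1.0] * len(propensities)
    if len(marginals) != len(propensities):
        raise ValueError("propensities and marginals must have equal length")
    weights, w = [], 1.0
    for p, m in zip(propensities, marginals):
        if not 0.0 < p <= 1.0:
            raise ValueError("propensities must lie in (0, 1]")
        w *= m / p                     # accumulate the product over time
        weights.append(min(w, clip))   # clip the reported weight for stability
    return weights


def weight_rewards(rewards: List[float], weights: List[float]) -> List[float]:
    """Scale each step's reward by that step's weight."""
    return [r * w for r, w in zip(rewards, weights)]
```

Clipping is a common, if blunt, way to keep a few low-propensity steps from dominating the actor's gradient; stabilized weights serve the same purpose more smoothly.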

Data preprocessing

List of used variables

Static variables: Age, Gender, Weight, Readmission to intensive care, Elixhauser score (premorbid status)

Time-varying variables: Modified SOFA, SIRS, Glasgow coma scale, Heart rate, Systolic, mean and diastolic blood pressure, Shock index, Respiratory rate, SpO2, Temperature, Potassium, Sodium, Chloride, Glucose, BUN, Creatinine, Magnesium, Calcium, Ionized calcium, Carbon dioxide, SGOT, SGPT, Total bilirubin, Albumin, Hemoglobin, White blood cell count, Platelet count, PTT, PT, INR, pH, PaO2, PaCO2, Base excess, Bicarbonate, Lactate, PaO2/FiO2 ratio, Mechanical ventilation, FiO2, IV fluid intake over 4h, Vasopressor over 4h, Urine output over 4h, Cumulated fluid balance since admission (includes preadmission data when available)

Treatment actions: positive end-expiratory pressure (PEEP), fraction of inspired oxygen (FiO2), ideal body weight-adjusted tidal volume (Vt)

Outcome: Hospital mortality, 90-day mortality
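The three treatment actions (PEEP, FiO2, Vt) span a 3-dimensional action space. One common way to give an actor network a discrete action is to bin each setting and combine the bin indices into a single id. The bin edges and the function `discretize_action` below are hypothetical placeholders for illustration, not the values or code used by DAC.

```python
from bisect import bisect_right
from typing import Tuple

# Hypothetical bin edges (NOT from the repository). k edges give k + 1 bins.
PEEP_EDGES = [5, 7, 9, 11, 13, 15]              # cmH2O
FIO2_EDGES = [30, 35, 40, 50, 60, 70]           # percent
VT_EDGES = [2.5, 5.0, 7.5, 10.0, 12.5, 15.0]    # ml/kg ideal body weight


def discretize_action(peep: float, fio2: float, vt: float) -> Tuple[Tuple[int, int, int], int]:
    """Map continuous ventilator settings to per-axis bins and a flat action id."""
    bins = (bisect_right(PEEP_EDGES, peep),
            bisect_right(FIO2_EDGES, fio2),
            bisect_right(VT_EDGES, vt))
    sizes = (len(PEEP_EDGES) + 1, len(FIO2_EDGES) + 1, len(VT_EDGES) + 1)
    # Row-major flattening: id = (peep_bin * n_fio2 + fio2_bin) * n_vt + vt_bin
    flat = (bins[0] * sizes[1] + bins[1]) * sizes[2] + bins[2]
    return bins, flat
```

With these edges there are 7 × 7 × 7 = 343 actions; for example, `discretize_action(8, 45, 6)` maps to bins `(2, 3, 2)` and flat id `121`.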

MIMIC-III dataset

Set up the MIMIC-III database (https://github.com/MIT-LCP/mimic-code), then run:

cd preprocessing
python extract_mimic_data.py
python mimic3_dataset.py
python mech_dataset.py
python find_patients_wth_mechvent.py

AmsterdamUMCdb dataset

Set up the AmsterdamUMCdb database (https://github.com/AmsterdamUMC/AmsterdamUMCdb), then run:

cd preprocessing
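jupyter nbconvert --to notebook --execute --inplace extract_features.ipynb
jupyter nbconvert --to notebook --execute --inplace mechanical_ventilation.ipynb
python generate_mechvent_variables.py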

Synthetic dataset

Simulate all covariates, treatments, and outcomes:

cd preprocessing
python synthetic_mimic.py
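Why confounding matters can be seen in a toy simulation. This is not `synthetic_mimic.py`; it is a hypothetical illustration in which a hidden severity variable makes sicker patients both more likely to be treated and more likely to die. The probabilities are arbitrary assumptions. The point is that the naive comparison makes a beneficial treatment look harmful, while stratifying on the confounder recovers the true effect of about -0.05.

```python
import random
from typing import Iterable, List, Tuple

Row = Tuple[bool, bool, bool]  # (confounder u, treated a, died y)


def simulate(n: int = 50_000, seed: int = 0) -> List[Row]:
    """Toy data: confounder u raises both treatment probability and mortality."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        u = rng.random() < 0.5                   # hidden high-severity indicator
        a = rng.random() < (0.8 if u else 0.2)   # sicker patients treated more often
        p_death = 0.1 + 0.5 * u - 0.05 * a       # treatment truly lowers risk by 0.05
        rows.append((u, a, rng.random() < p_death))
    return rows


def mean(xs: Iterable[bool]) -> float:
    xs = list(xs)
    return sum(xs) / len(xs)


def naive_effect(rows: List[Row]) -> float:
    """Mortality difference treated minus untreated, ignoring the confounder."""
    return (mean(y for _, a, y in rows if a)
            - mean(y for _, a, y in rows if not a))


def adjusted_effect(rows: List[Row]) -> float:
    """Stratify on the confounder and average within-stratum differences."""
    total = 0.0
    for level in (False, True):
        stratum = [r for r in rows if r[0] == level]
        diff = naive_effect(stratum)   # within a stratum, u is held fixed
        total += diff * len(stratum) / len(rows)
    return total
```

In this toy setup the naive estimate is around +0.25 (treatment appears to increase mortality), which mirrors how an outcome-only reward can punish helpful actions given to non-survivors.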

Train and test DAC

Split dataset

cd code
python split.py
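`split.py` is not documented here. As an illustration only, a patient-level split, in which every time step of a patient lands in the same subset, might look like the sketch below. The 70/10/20 ratios, the seed, and `split_patients` are assumptions, not the repository's settings.

```python
import random
from typing import Dict, Iterable, List, Sequence


def split_patients(patient_ids: Iterable[int],
                   ratios: Sequence[float] = (0.7, 0.1, 0.2),
                   seed: int = 42) -> Dict[str, List[int]]:
    """Split unique patient ids into train/valid/test without overlap."""
    ids = sorted(set(patient_ids))      # deduplicate, then fix order for reproducibility
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * ratios[0])
    n_valid = round(len(ids) * ratios[1])
    return {
        "train": ids[:n_train],
        "valid": ids[n_train:n_train + n_valid],
        "test": ids[n_train + n_valid:],   # remainder goes to test
    }
```

Splitting by patient rather than by row avoids leaking a patient's trajectory into both training and evaluation.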

Estimate mortality rate

cd code
python estimate_mortality.py
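The estimated mortality risks feed the patient resampling module described in the Framework section. A minimal sketch of pairing each non-survivor with the unused survivor of closest estimated risk (greedy 1:1 matching without replacement), then building mini-batches from whole pairs, is shown below. The matching rule, the high-risk-first processing order, and the function names `pair_by_risk` and `balanced_batches` are assumptions rather than the repository's implementation.

```python
import random
from typing import Iterator, List, Tuple

Patient = Tuple[str, float]  # (patient id, estimated mortality risk)


def pair_by_risk(survivors: List[Patient],
                 nonsurvivors: List[Patient]) -> List[Tuple[str, str]]:
    """Greedily pair each non-survivor with the closest-risk unused survivor."""
    available = list(survivors)
    pairs = []
    # High-risk non-survivors have the fewest close survivor matches, so go first.
    for pid, risk in sorted(nonsurvivors, key=lambda p: p[1], reverse=True):
        if not available:
            break
        j = min(range(len(available)), key=lambda k: abs(available[k][1] - risk))
        sid, _ = available.pop(j)   # match without replacement
        pairs.append((sid, pid))
    return pairs


def balanced_batches(pairs: List[Tuple[str, str]], pairs_per_batch: int,
                     seed: int = 0) -> Iterator[List[str]]:
    """Yield mini-batches containing equal numbers of survivors and non-survivors."""
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)
    for i in range(0, len(shuffled), pairs_per_batch):
        yield [pid for pair in shuffled[i:i + pairs_per_batch] for pid in pair]
```

Because each batch is built from whole pairs, survivors and non-survivors with comparable baseline risk always appear together, which is the balance the resampling module is meant to provide.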

Model training

cd code
python main.py

Visualization of results

  • Visualization of the action distribution in the 3-dimensional action space on MIMIC-III dataset.

  • Visualization of the action distribution in the 3-dimensional action space on AmsterdamUMCdb dataset.

  • The relations between mortality rates and mechanical ventilation setting difference (recommended setting - actual setting) on MIMIC-III dataset.

  • The relations between mortality rates and mechanical ventilation setting difference (recommended setting - actual setting) on AmsterdamUMCdb dataset.

  • The positive correlations between estimated mortality rate and predicted mortality probability on MIMIC-III and AmsterdamUMCdb datasets.

  • Mortality-expected-return curve computed by the learned policies.
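The relation between mortality and the setting difference (recommended setting - actual setting) can be approximated by binning the difference and computing the mortality rate in each bin. The record format, the bin width, and `mortality_by_difference` are hypothetical; the repository's own plotting code may aggregate per patient or per time step differently.

```python
import math
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def mortality_by_difference(records: Iterable[Tuple[float, bool]],
                            bin_width: float = 2.0) -> Dict[float, float]:
    """records: (recommended - actual setting, died) pairs.

    Returns {bin lower edge: mortality rate}, sorted by bin.
    """
    counts: Dict[float, List[int]] = defaultdict(lambda: [0, 0])  # [deaths, total]
    for diff, died in records:
        lower = math.floor(diff / bin_width) * bin_width   # left edge of the bin
        counts[lower][0] += int(died)
        counts[lower][1] += 1
    return {edge: deaths / total for edge, (deaths, total) in sorted(counts.items())}
```

For example, with a bin width of 2.0 the records `[(0.5, False), (1.5, True), (-0.5, True), (-1.0, False), (3.0, True)]` give `{-2.0: 0.5, 0.0: 0.5, 2.0: 1.0}`. A curve with its minimum near a difference of zero is what one would expect if following the recommendation is associated with lower mortality.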

Contributors

yinchangchang


Issues

The file step_3_start.pkl was not found

Hello,
In the files mimic_dataset.py and mech_dataset.py, the files step_3_start.pkl, sample_and_hold, and sepsis were used, but they were not found in the data directory.

Would you please explain a bit?
Thank you!

Would you please explain a bit about data preprocessing?

Hello,
I would like to ask whether observations are processed and aggregated into 4-hour windows, as in the paper "Development and validation of a reinforcement learning algorithm to dynamically optimize Mechanical Ventilation in Critical Care".
Also, how much time does each time step correspond to? Would you please explain a bit?
Thank you!
