nyumedml / gnn_for_ehr

242 stars · 15 watchers · 59 forks · 568 KB

Code for "Graph Neural Network on Electronic Health Records for Predicting Alzheimer’s Disease"

License: GNU General Public License v3.0

Python 90.20% Jupyter Notebook 9.80%
disease-prediction alzheimer-disease-prediction graph-neural-networks deep-learning ehr pytorch electronic-health-records gnn

gnn_for_ehr's People

Contributors

jackzhu727


gnn_for_ehr's Issues

model.py NaN values for sparse matrix multiplication

When training the VGNN, the attention module multiplies two sparse matrices, edge_e and data, to produce h_prime. The code then asserts that h_prime contains no NaN values, and this assertion fails for me. I tried checking edge_e before the multiplication by converting it to a dense matrix, but I don't think any NaN value is in it; I used assert not torch.isnan(edge_e.to_dense()).any() for the check. Below is a detailed trace of the error.

  File "/scratch/pw1287/GNN4EHR/utils.py", line 19, in train
    logits, kld = model(input)
  File "/home/pw1287/.conda/envs/GNN/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/pw1287/.conda/envs/GNN/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/pw1287/.conda/envs/GNN/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/pw1287/.conda/envs/GNN/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/home/pw1287/.conda/envs/GNN/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/home/pw1287/.conda/envs/GNN/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/pw1287/GNN4EHR/model.py", line 217, in forward
    outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
  File "/scratch/pw1287/GNN4EHR/model.py", line 217, in <listcomp>
    outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
  File "/scratch/pw1287/GNN4EHR/model.py", line 206, in encoder_decoder
    h_prime = self.out_att(output_edges, h_prime)
  File "/home/pw1287/.conda/envs/GNN/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/pw1287/GNN4EHR/model.py", line 111, in forward
    h_prime = torch.stack([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=0).mean(
  File "/scratch/pw1287/GNN4EHR/model.py", line 111, in <listcomp>
    h_prime = torch.stack([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=0).mean(
  File "/scratch/pw1287/GNN4EHR/model.py", line 99, in attention
    assert not torch.isnan(h_prime).any()

Is this an error on the input side? How can I get around it? Thanks.
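One plausible cause worth checking (a sketch with made-up shapes and logits, not the repo's actual tensors): in sparse GAT-style attention, edge_e is typically built with exp(), which can overflow float32 to inf. isnan() on edge_e then passes — there is no NaN yet — but the NaN appears one step later, when the weighted sum is normalised (inf / inf, or 0 / 0 for an isolated node).

```python
import torch

# Hypothetical edge index and logits mirroring a sparse GAT attention step.
N, D = 4, 3
idx = torch.tensor([[0, 1, 2], [1, 2, 0]])    # (2, E) edge index
scores = torch.tensor([1.0, 100.0, 3.0])      # raw attention logits

# exp(100) overflows float32 to inf, yet the isnan() check on edge_e
# still passes -- it contains inf, not NaN.
edge_e = torch.sparse_coo_tensor(idx, torch.exp(scores), (N, N)).coalesce()
assert not torch.isnan(edge_e.to_dense()).any()
assert torch.isinf(edge_e.to_dense()).any()

# The NaN surfaces after normalisation: inf / inf for the overflowed row,
# and 0 / 0 for node 3, which has no outgoing edges.
data = torch.randn(N, D)
rowsum = torch.sparse.mm(edge_e, torch.ones(N, 1))
h_prime = torch.sparse.mm(edge_e, data) / rowsum
print(torch.isnan(h_prime).any())  # tensor(True)
```

If this matches what you see, checking torch.isinf(edge_e.to_dense()).any() (and for all-zero rows in the normaliser) should localise the problem.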

Possible small bug in code?

Hi,

Thank you for making the code public. It's really nice to see!

I think the following line here

self.norm = LayerNorm(hidden_features)

should be

if concat:
    self.norm = LayerNorm(hidden_features * num_heads)

Also, apologies if I have missed this, but if the output of the multi-headed attention is concatenated, shouldn't that be reflected in the size of the shared weights W at successive layers? Currently it is constant at d_in x d, but it should be dK x d at intermediate layers.

self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
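A minimal sketch of the dimension argument (the names here are illustrative, not the repo's actual API): concatenating K heads of width d yields K*d features, so both the LayerNorm and the next layer's weights must be sized accordingly.

```python
import torch
import torch.nn as nn

# Illustrative sizes: 4 heads, each projecting 16 -> 8 features.
in_features, hidden, num_heads = 16, 8, 4
x = torch.randn(10, in_features)

heads = [nn.Linear(in_features, hidden) for _ in range(num_heads)]
h = torch.cat([w(x) for w in heads], dim=-1)   # shape (10, hidden * num_heads)

# LayerNorm(hidden) would raise a shape error here; it must match the
# concatenated width, as the issue suggests:
norm = nn.LayerNorm(hidden * num_heads)
out = norm(h)
print(out.shape)   # torch.Size([10, 32])

# Likewise, the next layer's shared weights must accept hidden * num_heads
# inputs rather than in_features:
next_W = nn.Linear(hidden * num_heads, hidden)
```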

Variables Not Used in preprocess_mimic.py

In the process_patient(infile) function of preprocess_mimic.py, the code segment below declares two dictionaries for storing patient IDs and the relevant encounter information:

    for patient_id, time_enc_tuples in patient_dict.items():
        patient_dict_sorted[patient_id] = sorted(time_enc_tuples)

The variables patient_dict and patient_dict_sorted are neither used later inside the function nor returned to its caller. I think this part of the code could be removed? If not, could you please explain this code segment? Thanks.

Input data format

Hello,
Great work!
Can you please share the format/structure of the pickle files which include the input data being loaded in dataloader.py?
Thanks.

Questions about the mortality experiment

Hi,

Thank you for making the code public, excellent work!

I just have one small question. As the paper states, the data for the mortality task is processed with the following criterion:

To avoid potential data leakage between mortality and the preventative events immediately preceding it, we only include the chart events within the first 24 hours after ICU admission as the input for the mortality prediction task.

However, the experiment uses tables like DIAGNOSES_ICD and PROCEDURES_ICD, which don't seem to contain any timestamps. How can I filter out ICD codes recorded more than 24 hours after admission?

Thank you very much!
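For the timestamped tables, the window the paper describes can be applied by joining chart events to ICU stays on their admission time. A sketch with toy in-memory frames (the column names follow the MIMIC-III schema for ICUSTAYS and CHARTEVENTS; the rows are made up). Note that DIAGNOSES_ICD and PROCEDURES_ICD are coded per admission without event times, so they cannot be windowed this way:

```python
import pandas as pd

# Toy frames standing in for MIMIC-III's ICUSTAYS and CHARTEVENTS tables.
icu = pd.DataFrame({
    "ICUSTAY_ID": [1],
    "INTIME": pd.to_datetime(["2150-01-01 00:00"]),
})
chart = pd.DataFrame({
    "ICUSTAY_ID": [1, 1, 1],
    "ITEMID": [211, 211, 211],
    "CHARTTIME": pd.to_datetime(
        ["2150-01-01 06:00", "2150-01-01 23:00", "2150-01-02 08:00"]
    ),
})

# Keep only chart events within 24 hours of ICU admission.
merged = chart.merge(icu, on="ICUSTAY_ID")
first_24h = merged[
    merged["CHARTTIME"] <= merged["INTIME"] + pd.Timedelta(hours=24)
]
print(len(first_24h))  # 2
```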

XGBoost results?

Hi, why did you not include XGBoost results in your model-scores table? It is the de facto ML standard for tabular data.


As far as I understand, you tried to predict mortality using the first 24 hours of data from admission:

mortality prediction at 24 hours after admission, based on the MIMIC-III cohort

Did you take all MIMIC-III patients, or some cohort, e.g. patients with sepsis?

XGBoost results for all-patient mortality prediction using 24h-from-admission data:
Johnson, A. E. W. & Mark, R. G. Real-time mortality prediction in the Intensive Care Unit. AMIA Annu. Symp. Proc. 2017, 994–1003 (2018).

-

Ignore this issue; I confused it with another paper.

Where to get the needed CSV files

Excellent work!

Can you provide us with samples of the needed CSV files? It would help me a lot in understanding the EHR data and your processing flow.

Thank you very much!
