
ai-med / triplettraining

Official PyTorch Implementation for From Barlow Twins to Triplet Training: Differentiating Dementia with Limited Data - MIDL 2024

License: GNU General Public License v3.0

Python 100.00%

triplettraining's People

Contributors

yiiitong


triplettraining's Issues

HDF5 file creation - UKB MRI data

Hello,
I would like to produce the HDF5 files for UK Biobank (UKB) pre-training as done in your method, but I run into memory problems while generating them.
Do you have advice on the specific process you followed to integrate your N = 39,560 MRI samples (presumably with a specific split to obtain train_data.h5 and valid_data.h5)?

I am working on the DNAnexus platform, on an instance of type mem2_ssd1_gpu1_x32 with 129 GB of total memory, 837 GB of total storage, and 32 cores. I use T2-FLAIR MRIs in .nii.gz format, each about 2.3 MB in size.
The detailed usage is as follows:
[screenshot of resource usage omitted]

My Python script uses multiprocessing to generate temporary HDF5 files and then combines them into a final train_data.h5:

import h5py
import pandas as pd
import numpy as np
import os
import nibabel as nib
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import sys
import time
import logging

# Set up logging
logging.basicConfig(filename='process.log', level=logging.INFO, format='%(asctime)s %(message)s')

def load_image(image_path):
    img = nib.load(image_path).get_fdata()
    return img

def process_batches_to_h5(batch_eids, image_folder, temp_file_path, progress_position, progress_queue):
    no_image_found = []
    with h5py.File(temp_file_path, 'w', libver='latest') as file:
        for eid in batch_eids:
            start_time = time.time()
            image_path = os.path.join(image_folder, f"{eid}/{eid}_T2_FLAIR_brain_to_MNI.nii.gz")
            if os.path.exists(image_path):
                img = load_image(image_path)
                group = file.create_group(str(eid))
                group.create_dataset('MRI/T2/data', data=img.astype(np.float32))
            else:
                no_image_found.append(eid)
                continue
            end_time = time.time()
            process_time = end_time - start_time
            log_message = f"Patient {eid} processed in {process_time:.2f} seconds"
            progress_queue.put((progress_position, log_message, process_time))
            logging.info(log_message)
    return no_image_found

def combine_hdf5_files(output_file_path, temp_file_paths):
    with h5py.File(output_file_path, 'w', libver='latest') as output_file:
        for temp_file_path in temp_file_paths:
            with h5py.File(temp_file_path, 'r') as temp_file:
                for eid in temp_file.keys():
                    temp_file.copy(eid, output_file)
            os.remove(temp_file_path)

def create_hdf5_from_eids(csv_file_path, image_folder, output_file_path, batch_size=1000, num_workers=32):
    df = pd.read_csv(csv_file_path)
    eids = df['eid'].values
    num_batches = (len(eids) + batch_size - 1) // batch_size  # ceiling division
    batched_eids = np.array_split(eids, num_batches)

    manager = multiprocessing.Manager()
    progress_queue = manager.Queue()

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        temp_file_paths = []

        for i, batch_eids in enumerate(batched_eids):
            temp_file_path = f'temp_{i}.h5'
            temp_file_paths.append(temp_file_path)
            future = executor.submit(process_batches_to_h5, batch_eids, image_folder, temp_file_path, i, progress_queue)
            futures.append(future)

        total_patients = len(eids)
        progress_bar = tqdm(total=total_patients, desc="Overall Progress", leave=True, file=sys.stdout)

        while any(future.running() for future in futures):
            while not progress_queue.empty():
                batch_position, message, process_time = progress_queue.get()
                progress_bar.update(1)
                tqdm.write(message)
            time.sleep(0.1)  # avoid busy-waiting on the queue

        # Drain messages that arrived after the last worker finished.
        while not progress_queue.empty():
            _, message, _ = progress_queue.get()
            progress_bar.update(1)
            tqdm.write(message)

        progress_bar.close()

        no_image_found_all = []
        for future in futures:
            no_image_found_all.extend(future.result())

        combine_hdf5_files(output_file_path, temp_file_paths)

    if no_image_found_all:
        logging.info("Eids for which no image was found:")
        logging.info(no_image_found_all)
    logging.info(f"HDF5 file {output_file_path} created successfully.")

# Define paths
csv_file_path_train = "mri_eids_UKB_train.csv"
csv_file_path_val = "mri_eids_UKB_val.csv"
image_folder = '/mnt/project/Data/brain_MRI/T2_lesion_T1seg/'
train_output_file_path = 'train_data.h5'
val_output_file_path = 'valid_data.h5'

# Create HDF5 files for training and validation sets
create_hdf5_from_eids(csv_file_path_train, image_folder, train_output_file_path)
create_hdf5_from_eids(csv_file_path_val, image_folder, val_output_file_path)

print(f"Training data saved to {train_output_file_path}")
print(f"Validation data saved to {val_output_file_path}")

I face a trade-off between time and space: file generation takes much longer when I add the following line to reduce the size of each temporary HDF5 file:

group.create_dataset('MRI/T2/data', data=img.astype(np.float32), chunks=True, compression="gzip")

but removing it produces the following error:

RuntimeError: Dirty entry flush destroy failed (file write failed: time = Wed May 22 14:59:05 2024, filename = 'temp_0.h5', file descriptor = 16, errno = 28, error message = 'No space left on device', buf = 0x56328dc1aeb0, total write size = 1891, bytes this sub-write = 1891, bytes actually written = 18446744073709551615, offset = 0)
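For reference, h5py also supports filters that are typically much faster than the default gzip setting, which might soften the time/space trade-off. A minimal sketch, using a small placeholder array instead of a real volume and a hypothetical eid:

```python
import numpy as np
import h5py

img = np.zeros((16, 16, 16), dtype=np.float32)  # placeholder volume

with h5py.File("filter_demo.h5", "w") as f:
    grp = f.create_group("1000001")  # hypothetical eid
    # LZF: much faster than gzip, at a somewhat lower compression ratio.
    grp.create_dataset("MRI/T2/data_lzf", data=img,
                       chunks=True, compression="lzf")
    # gzip level 1: faster than the default level 4, still compresses well.
    grp.create_dataset("MRI/T2/data_gz1", data=img,
                       chunks=True, compression="gzip", compression_opts=1)
```

Either filter keeps the dataset layout identical, so the rest of the script would be unchanged.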

For now, I have 32 temporary files (num processors = max num_workers = 32), each with a target size of 1000 patients. Currently (i.e., after the script terminated), they are all 12 GB in size. On the DNAnexus platform, I load the data from the project stored on the cloud and write these files to the local temporary environment, from which I can transfer the final files back to the cloud after creation.
My script was at this point when I got out of storage:
Overall Progress: 49%|█████████████████████████████▊ | 15615/31913 [3:10:47<3:19:07, 1.36it/s]

My T2-FLAIR MRIs are registered to MNI atlas and are (182, 218, 182) in shape.
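A back-of-the-envelope estimate from the shape above suggests the uncompressed float32 volumes alone would overflow the instance's 837 GB of local storage, which would explain the ENOSPC error (31,913 is the progress-bar total from the run above):

```python
# Estimate uncompressed storage for the float32 T2-FLAIR volumes.
shape = (182, 218, 182)          # MNI-registered volume shape
voxels = shape[0] * shape[1] * shape[2]
bytes_per_volume = voxels * 4    # float32 = 4 bytes per voxel
n_volumes = 31913                # total from the progress bar
total_gb = n_volumes * bytes_per_volume / 1e9
print(f"{bytes_per_volume / 1e6:.1f} MB per volume, {total_gb:.0f} GB total")
# -> 28.9 MB per volume, 922 GB total (more than the 837 GB available)
```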

I would appreciate any advice or input from your side on faster and more storage-efficient ways to handle the file creation!
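For what it's worth, one pattern that would avoid the temporary files (and the combine pass) entirely is a single-writer layout, where workers only load volumes and the main process owns the output file. A sketch with a placeholder loader and hypothetical eids, shown with a thread pool for brevity (ProcessPoolExecutor has the same interface if loading turns out to be CPU-bound):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import h5py

def load_volume(eid):
    # Placeholder loader; in practice this would be
    # nib.load(...).get_fdata().astype(np.float32).
    return eid, np.zeros((16, 16, 16), dtype=np.float32)

def build_h5(eids, output_path, num_workers=4):
    # Workers only load volumes; the single writer below owns the HDF5
    # file, so no temporary files and no combine step are needed.
    with h5py.File(output_path, "w") as out, \
         ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(load_volume, e) for e in eids]
        for fut in as_completed(futures):
            eid, img = fut.result()
            out.create_dataset(f"{eid}/MRI/T2/data", data=img,
                               chunks=True, compression="lzf")

build_h5([1000001, 1000002, 1000003], "single_writer.h5")
```

One caveat: each pending result holds a full ~29 MB array in memory, so in practice the number of in-flight submissions should be bounded.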

Best, and thank you in advance for your answer,
