
vis_dec_neurips's Introduction

Contrast, Attend and Diffuse to Decode High-Resolution Images from Brain Activities

Jingyuan Sun*, Mingxiao Li*, Zijiao Chen, Yunhao Zhang, Shaonan Wang and Marie-Francine Moens. In Proceedings of the Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS 2023).


1. Abstract

Decoding visual stimuli from neural responses recorded by functional Magnetic Resonance Imaging (fMRI) presents an intriguing intersection between cognitive neuroscience and machine learning, promising advancements in understanding human visual perception. However, the task is challenging due to the noisy nature of fMRI signals and the intricate pattern of brain visual representations. To mitigate these challenges, we introduce a two-phase fMRI representation learning framework. The first phase pre-trains an fMRI feature learner with a proposed Double-contrastive Mask Auto-encoder to learn denoised representations. The second phase tunes the feature learner to attend to neural activation patterns most informative for visual reconstruction with guidance from an image auto-encoder. The optimized fMRI feature learner then conditions a latent diffusion model to reconstruct image stimuli from brain activities. Experimental results demonstrate our model's superiority in generating high-resolution and semantically accurate images, substantially exceeding previous state-of-the-art methods by 39.34% in 50-way-top-1 semantic classification accuracy. The code implementation is publicly available in this repository.
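
As background on this metric: in an n-way top-1 evaluation, a pretrained image classifier scores the reconstructed image, and the reconstruction counts as correct if the ground-truth class outranks n-1 randomly sampled distractor classes. The sketch below is a minimal illustration under that assumption (the classifier choice, trial count and distractor sampling here are placeholders); the exact protocol is the one described in the paper.

# Minimal sketch of an n-way top-1 semantic classification metric (illustrative only;
# the classifier, number of trials and sampling protocol are assumptions).
import random
import torch

@torch.no_grad()
def n_way_top1(classifier, recon_img, gt_class, n_way=50, n_trials=100, num_classes=1000):
    # recon_img: (1, 3, H, W) tensor already preprocessed for the classifier.
    probs = classifier(recon_img).softmax(dim=-1).squeeze(0)  # (num_classes,)
    hits = 0
    for _ in range(n_trials):
        distractors = random.sample([c for c in range(num_classes) if c != gt_class], n_way - 1)
        candidates = [gt_class] + distractors
        # Correct if the ground-truth class (index 0 in candidates) scores highest.
        hits += int(torch.argmax(probs[candidates]).item() == 0)
    return hits / n_trials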

2. Method Overview

(Framework flowchart: see the figure in the repository.)

We propose a two-phase fMRI representation learning (FRL) framework. In Phase 1, we pre-train an MAE with a contrastive loss to learn fMRI representations from unlabeled data. In Phase 2, after pre-training, we tune the fMRI auto-encoder together with an image auto-encoder. After FRL Phases 1 and 2, we use the representations learned by the fMRI auto-encoder as conditions to tune the LDM and generate images from brain activities.

3. Training Procedure

3.0 Setting Environments

Create and activate a conda environment named vis_dec from env.yaml:

conda env create -f env.yaml
conda activate vis_dec

3.1 FRL Phase 1

Overview

In Phase 1, we pre-train an MAE with a contrastive loss to learn fMRI representations from unlabeled fMRI data from HCP. Masking, which sets a portion of the input to zero, targets the spatial redundancy of fMRI data; recovering the original data from the remaining voxels after masking suppresses noise. Optimizing the contrastive loss discerns common patterns of brain activity over individual variances.
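
To make the contrastive terms concrete, here is a generic InfoNCE-style sketch of a contrastive loss over paired fMRI representations. How the positive and negative pairs are actually formed for the self-contrast and cross-contrast losses is defined by the paper and the code in code/phase1_pretrain_contrast.py; the function below is only an illustrative assumption.

# Generic InfoNCE-style contrastive loss sketch (not the repository's exact loss).
# q and k are (batch, dim) embeddings where q[i] and k[i] form a positive pair.
import torch
import torch.nn.functional as F

def info_nce(q, k, temperature=0.07):
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    logits = q @ k.t() / temperature                     # (batch, batch) similarities
    labels = torch.arange(q.size(0), device=q.device)    # positives lie on the diagonal
    return F.cross_entropy(logits, labels)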

Preparing Data

In this phase, we use fMRI samples released by HCP as pre-training data. Due to size limitations and licensing constraints, please download the data from the official website (https://db.humanconnectome.org/data/projects/HCP_1200), put it in the ./data/HCP directory, and preprocess it with ./data/HCP/preprocess_hcp.py. The resulting data and directory structure look like the tree below; a small sanity-check sketch follows it.

/data
┣ 📂 HCP
┃   ┣ 📂 npz
┃   ┃   ┣ 📂 dummy_sub_01
┃   ┃   ┃   ┗ HCP_visual_voxel.npz
┃   ┃   ┣ 📂 dummy_sub_02
┃   ┃   ┃   ┗ ...
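
To sanity-check the preprocessed files before launching pre-training, a small inspection script like the one below can be used. It assumes nothing about the array names inside HCP_visual_voxel.npz, which are determined by preprocess_hcp.py.

# Quick sanity check of the preprocessed HCP data: list the arrays and their shapes.
import glob
import numpy as np

for path in sorted(glob.glob("./data/HCP/npz/*/HCP_visual_voxel.npz")):
    with np.load(path) as npz:
        shapes = {name: npz[name].shape for name in npz.files}
    print(path, shapes)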

Training Model

You can run

python -m torch.distributed.launch --nproc_per_node=1 code/phase1_pretrain_contrast.py \
--output_path . \
--contrast_loss_weight 1 \
--batch_size 250 \
--do_self_contrast True \
--do_cross_contrast True \
--self_contrast_loss_weight 1 \
--cross_contrast_loss_weight 0.5 \
--mask_ratio 0.75 \
--num_epoch 140

to pretrain the model by yourself. do_self_contrast and do_cross_contrast control whether the self-contrast and cross-contrast losses are used; self_contrast_loss_weight and cross_contrast_loss_weight set the weights of the self-contrast and cross-contrast losses in the joint loss. The weighting logic is sketched below.
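
Conceptually, these flags gate and weight the contrastive terms in the joint training objective, roughly as in the sketch below (the variable names are illustrative, not the repository's).

# Illustrative combination of the Phase 1 losses controlled by the flags above.
# The individual loss terms are assumed to be computed elsewhere; only the
# gating and weighting logic is sketched here.
def phase1_joint_loss(recon_loss, self_contrast_loss, cross_contrast_loss, args):
    loss = recon_loss
    if args.do_self_contrast:
        loss = loss + args.self_contrast_loss_weight * self_contrast_loss
    if args.do_cross_contrast:
        loss = loss + args.cross_contrast_loss_weight * cross_contrast_loss
    return loss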

You can also download our pretrained ckpt from https://1drv.ms/u/s!AlmPyF18ti-A3XmuKMPEfVNdvmsT?e=3bZ0jj

3.2 FRL Phase 2

Overview

After pre-training in Phase 1, we tune the fMRI auto-encoder with an image auto-encoder. We expect the pixel-level guidance from the image auto-encoder to support the fMRI auto-encoder in disentangling and attending to brain signals related to vision processing.

Preparing Data

We use the Generic Object Decoding (GOD) and BOLD5000 datasets in this phase. GOD is a specialized resource developed for fMRI-based decoding. It aggregates fMRI data gathered through the presentation of images from 200 representative object categories, originating from the 2011 fall release of ImageNet. The training session incorporated 1,200 images (8 per category from 150 distinct object categories). The test session included 50 images (one from each of 50 object categories). The categories in the test session were distinct from those in the training session and were presented in a randomized sequence across runs. fMRI scanning was conducted on five subjects. BOLD5000 is the result of an extensive slow event-related human brain fMRI study. It comprises 5,254 images, 4,916 of which are unique. The images in BOLD5000 were selected from three popular computer vision datasets: ImageNet, COCO, and Scenes.

We provide processed versions of these datasets, which can be downloaded from https://1drv.ms/u/s!AlmPyF18ti-A3Xec-3-PdsaO230u?e=ivcd7L . Please download and uncompress the archive into ./data. The resulting directory looks like:


/data
┣ 📂 Kamitani
┃   ┣ 📂 npz
┃   ┃   ┗ 📜 sbj_1.npz
┃   ┃   ┗ 📜 sbj_2.npz
┃   ┃   ┗ 📜 sbj_3.npz
┃   ┃   ┗ 📜 sbj_4.npz
┃   ┃   ┗ 📜 sbj_5.npz
┃   ┃   ┗ 📜 images_256.npz
┃   ┃   ┗ 📜 imagenet_class_index.json
┃   ┃   ┗ 📜 imagenet_training_label.csv
┃   ┃   ┗ 📜 imagenet_testing_label.csv

┣ 📂 BOLD5000
┃   ┣ 📂 BOLD5000_GLMsingle_ROI_betas
┃   ┃   ┣ 📂 py
┃   ┃   ┃   ┗ CSI1_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_LHEarlyVis.npy
┃   ┃   ┃   ┗ ...
┃   ┃   ┃   ┗ CSIx_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_xx.npy
┃   ┣ 📂 BOLD5000_Stimuli
┃   ┃   ┣ 📂 Image_Labels
┃   ┃   ┣ 📂 Scene_Stimuli
┃   ┃   ┣ 📂 Stimuli_Presentation_Lists

Training Model

You can run the following commands to get the fMRI encoder that we use to produce the reported reconstruction performance on GOD subject 3 in the paper.

python -m torch.distributed.launch --nproc_per_node=4 code/phase2_finetune_cross.py \
--dataset GOD \
--pretrain_mbm_path your_pretrained_ckpt_from_phase1 \
--batch_size 4 \
--num_epoch 60 \
--fmri_decoder_layers 6 \
--img_decoder_layers 6 \
--fmri_recon_weight 0.25 \
--img_recon_weight 1.5 \
--output_path your_output_path \
--img_mask_ratio 0.5 \
--mask_ratio 0.75
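
The fmri_recon_weight and img_recon_weight flags above balance two reconstruction terms: recovering the masked fMRI signal and recovering the image with guidance from the image auto-encoder. A minimal sketch of such a weighted combination is shown below; the decomposition is an illustrative assumption, and the actual losses are defined in code/phase2_finetune_cross.py.

# Rough sketch of how the two Phase 2 reconstruction terms might be combined;
# the defaults mirror the flags in the command above.
def phase2_joint_loss(fmri_recon_loss, img_recon_loss,
                      fmri_recon_weight=0.25, img_recon_weight=1.5):
    # The image term carries the larger weight, emphasising pixel-level guidance,
    # while the fMRI term keeps the encoder faithful to the brain signal.
    return fmri_recon_weight * fmri_recon_loss + img_recon_weight * img_recon_loss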

You can also download our trained ckpt from https://1drv.ms/u/s!AlmPyF18ti-A3XjJEkOfBELTl71W?e=wlihVF

3.3 Tuning LDM

Overview

In this stage, we use the representations learned by the fMRI auto-encoder in FRL Phases 1 and 2 as conditions to fine-tune a latent diffusion model (LDM), which then reconstructs the image stimuli from brain activities.
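
One common way to condition a latent diffusion model on an external embedding is to project it into the context consumed by the U-Net's cross-attention layers. The sketch below shows such a projection in isolation; the dimensions, token count and module names are assumptions for illustration and do not reflect the repository's actual implementation.

# Illustrative projection of fMRI encoder features into a cross-attention context
# for a latent diffusion model. All dimensions and names here are assumptions.
import torch
import torch.nn as nn

class FMRICondProjector(nn.Module):
    def __init__(self, fmri_dim=1024, context_dim=512, n_tokens=4):
        super().__init__()
        self.n_tokens = n_tokens
        self.context_dim = context_dim
        self.proj = nn.Linear(fmri_dim, n_tokens * context_dim)

    def forward(self, fmri_feat):
        # fmri_feat: (batch, fmri_dim) pooled output of the tuned fMRI encoder.
        ctx = self.proj(fmri_feat)
        # (batch, n_tokens, context_dim) context fed to the LDM's cross-attention.
        return ctx.view(-1, self.n_tokens, self.context_dim)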

Tuning Model

You can run the following commands to produce the reported reconstruction performance on GOD subject 3 in the paper.

python code/ldm_finetune.py --pretrain_mbm_path your_phase2_ckpt_path \
--num_epoch 700 \
--batch_size 8 \
--is_cross_mae \
--dataset GOD \
--kam_subs sbj_3 \
--target_sub_train_proportion 1.0 \
--lr 5.3e-5

Acknowledgements

A large part of the code is inherited from our previous work Mind-Vis. We express our gratitude to the following entities for generously sharing their raw and pre-processed data with the public: Kamitani Lab, Weizmann Vision Lab, and the BOLD5000 team. Our implementation of Masked Brain Modeling is built upon Masked Autoencoders by Facebook Research, and our Conditional Latent Diffusion Model implementation is based on the Latent Diffusion Model repository from CompVis. We extend our appreciation to these authors for openly sharing their code and checkpoints.

Citation

Please cite our paper if the code is useful for you.

@inproceedings{
sun2023contrast,
title={Contrast, Attend and Diffuse to Decode High-Resolution Images from Brain Activities},
author={Jingyuan Sun and Mingxiao Li and Yunhao Zhang and Marie-Francine Moens and Zijiao Chen and Shaonan Wang},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=YZSLDEE0mw}
}


vis_dec_neurips's Issues

LDM Model Config File

I'm trying to train the LDM myself, but can't find the config file that is referenced in the following lines of code:

class fLDM:
    def __init__(self, metafile, num_voxels, device=torch.device('cpu'),
                 pretrain_root='../pretrains/ldm/label2img',
                 logger=None, ddim_steps=250, global_pool=True, use_time_cond=True):
        self.ckp_path = os.path.join(pretrain_root, 'model.ckpt')
        self.config_path = os.path.join(pretrain_root, 'config.yaml')
        config = OmegaConf.load(self.config_path)

Would it be possible to upload that file to GitHub too? It would also be great if the trained LDM checkpoint were available for comparison.

Cannot download from the HCP dataset website

I am physically in the US. When I open the site, it asks me to install a plugin, the Aspera Connect installer, but there is no option to proceed to the next step.
Moreover, the site hosts many resources; many of them are labeled "1200 subjects", and even if I could download them I would not know which one to pick. The smallest is 2.3 GB and the largest ones are several TB.

Unable to replicate the stageB LDM finetune

This is indeed intriguing work!

Would it be possible for you to share the LDM pre-trained model model.ckpt and the associated configuration file config.yaml used in the stageB_ldm_finetune?
Alternatively, could you kindly guide me on where I might find these files?


The HCP datasets problem

Thanks for your great work and code!
I am currently exploring the HCP data and noticed that you provided a website link.
Could you please specify which dataset from the provided link you used in your research?
And could you kindly confirm whether the image shown below corresponds to the dataset you utilized in your work?

Thank you very much for your time and assistance!
(Screenshot of the HCP dataset page)

fMRI Encoder checkpoints

Hi, are the weights of the pre-trained fMRI encoder available? I would like to use the pre-trained model as it is, without further pre-training.
