
med-flamingo's Introduction

Med-Flamingo

This is the code repo for the Med-Flamingo paper.

More updates to follow soon!

Setup

Create a virtual environment, e.g.:

$ virtualenv flam_env
$ source flam_env/bin/activate

Install dependencies (we assume a GPU with CUDA is available):

$ source install.sh

Setting up Llama-7B (v1) locally

Due to recent changes in tokenizer class names, loading the model directly from the Hugging Face (hf) hub may lead to problems.
We recommend downloading the model manually, e.g. into a new directory models/, as follows:

$ git lfs install
$ git clone https://huggingface.co/decapoda-research/llama-7b-hf

In tokenizer_config.json, set:
"tokenizer_class": "LlamaTokenizer"
Now, you should be all set.
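If you prefer to script the tokenizer_config.json edit above instead of doing it by hand, here is a minimal sketch. The helper name set_tokenizer_class is hypothetical (not part of this repo); it only rewrites the tokenizer_class key and leaves every other entry in the file untouched.

```python
import json
from pathlib import Path


def set_tokenizer_class(model_dir: str, tokenizer_class: str = "LlamaTokenizer") -> dict:
    """Set "tokenizer_class" in <model_dir>/tokenizer_config.json and save the file.

    All other keys are preserved. Returns the updated config.
    """
    config_path = Path(model_dir) / "tokenizer_config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    # Replace whatever class name the downloaded checkpoint shipped with.
    config["tokenizer_class"] = tokenizer_class
    config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")
    return config
```

For example, after cloning into models/ as described above, you would call `set_tokenizer_class("models/llama-7b-hf")` (adjust the path to wherever the clone actually landed).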

Demo

  1. Go to scripts/

  2. Edit demo.py and enter your Llama-7B path (v1).

  3. Run:

$ python demo.py

Citing

If you found this repository interesting, please consider citing our pre-print:

@article{moor2023medflamingo,
    title={Med-Flamingo: A Multimodal Medical Few-shot Learner},
    author={Moor, Michael and Huang, Qian and Wu, Shirley and Yasunaga, Michihiro and Zakka, Cyril and Dalmia, Yash and Reis, Eduardo Pontes and Rajpurkar, Pranav and Leskovec, Jure},
    year={2023},
    month={July},
    note={arXiv:2307.15189},
    url={https://arxiv.org/abs/2307.15189}
}  

Furthermore, the following two references enabled our project in the first place:

@software{anas_awadalla_2023_7733589,
  author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig},
  title = {OpenFlamingo},
  month        = mar,
  year         = 2023,
  publisher    = {Zenodo},
  version      = {v0.1.1},
  doi          = {10.5281/zenodo.7733589},
  url          = {https://doi.org/10.5281/zenodo.7733589}
}
@article{Alayrac2022FlamingoAV,
  title={Flamingo: a Visual Language Model for Few-Shot Learning},
  author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
  journal={ArXiv},
  year={2022},
  volume={abs/2204.14198}
}

med-flamingo's People

Contributors

mi92

med-flamingo's Issues

Prompting "Med-Flamingo" issues

I've been using the 'Med-Flamingo' model for medical text generation and have frequently encountered issues like hallucinations and inaccuracies. I'm looking for insights into the model's training data and effective ways to improve prompts to achieve the results mentioned in the paper.
@mi92 any guidance on these aspects would be greatly appreciated. Thank you!

The model always outputs `pregnancy` for diagnosis questions

Dear team,

Thanks for sharing the great work.

I'd like to use the model for diagnosis. Here is my prompt:

prompt = "You are a helpful medical assistant. You are being provided with images, a question about the image and an answer. Follow the examples and answer the last question. <image>Question: What is/are the structure near/in the middle of the brain? Answer: pons.<|endofchunk|><image>Question: Is there evidence of a right apical pneumothorax on this chest x-ray? Answer: yes.<|endofchunk|><image>Question: Is/Are there air in the patient's peritoneal cavity? Answer: no.<|endofchunk|><image>Question: Does the heart appear enlarged? Answer: yes.<|endofchunk|><image>Question: What side are the infarcts located? Answer: bilateral.<|endofchunk|><image>Question: Which image modality is this? Answer: mr flair.<|endofchunk|><image>Question: What is the most likely diagnosis? Answer:"

However, the model always outputs pregnancy for all kinds of CT and MR images. What should be the correct prompts for diagnosis questions?
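The prompt in this issue follows an interleaved few-shot layout: each example is an `<image>` marker, a question, an answer, and an `<|endofchunk|>` separator, and the query ends with an open "Answer:". A small helper that assembles this string from (question, answer) pairs makes it easier to swap in shots that are closer to the target task, such as diagnosis examples instead of generic VQA ones. This is a sketch that mirrors the string above; build_fewshot_prompt is a hypothetical name, not a function from the repository.

```python
from typing import List, Tuple

IMAGE_TOKEN = "<image>"
END_OF_CHUNK = "<|endofchunk|>"


def build_fewshot_prompt(
    instruction: str,
    shots: List[Tuple[str, str]],
    query_question: str,
) -> str:
    """Assemble an interleaved few-shot prompt in the layout used above.

    Each shot contributes one <image> marker, so the caller must supply
    len(shots) + 1 images: the shots' images followed by the query image.
    """
    parts = [instruction.strip() + " "]
    for question, answer in shots:
        parts.append(f"{IMAGE_TOKEN}Question: {question} Answer: {answer}{END_OF_CHUNK}")
    # The query gets its own <image> marker and an open answer slot.
    parts.append(f"{IMAGE_TOKEN}Question: {query_question} Answer:")
    return "".join(parts)
```

Tying the marker count to the number of shots makes it harder to end up with a prompt whose `<image>` markers and supplied images disagree.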

Dead repo

The repo for LLaMA is dead now. Does the implementation work with other repos?

Best regards,
Benedikt

Always gives the same output for different inputs

Hi

I tried to use med-flamingo on a medical image dataset where the task is to predict which view of the heart is shown in an echocardiogram image. I followed your demo, using your default in-context example images and prompt and only modifying the question:

prompt = "You are a helpful medical assistant in Echardiology. You are being provided with images, a question about the image and an answer. Follow the examples and answer the last question. Question: What is/are the structure near/in the middle of the brain? Answer: pons.<|endofchunk|>Question: Is there evidence of a right apical pneumothorax on this chest x-ray? Answer: yes.<|endofchunk|>Question: Is/Are there air in the patient's peritoneal cavity? Answer: no.<|endofchunk|>Question: Does the heart appear enlarged? Answer: yes.<|endofchunk|>Question: What side are the infarcts located? Answer: bilateral.<|endofchunk|>Question: Which image modality is this? Answer: mr flair.<|endofchunk|>Question: Is the image presenting an A4C or A2C or PLAX or PSAX or other view of the heart? Answer:"

However, no matter which query image I use, it always gives the same output.
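As printed here, the prompt contains no `<image>` markers, even though the demo supplies in-context images plus a query image (the markers may also have been lost when this issue was rendered). In Flamingo-style models the `<image>` tokens roughly mark where each image's visual features enter the text stream, so a mismatch between markers and supplied images is one thing worth ruling out when the output ignores the query image. A minimal pre-flight check, using a hypothetical helper name:

```python
IMAGE_TOKEN = "<image>"


def check_image_tokens(prompt: str, num_images: int) -> None:
    """Raise ValueError if the number of <image> markers differs from num_images.

    num_images should count every in-context image plus the query image.
    """
    found = prompt.count(IMAGE_TOKEN)
    if found != num_images:
        raise ValueError(
            f"prompt has {found} {IMAGE_TOKEN} marker(s) but {num_images} image(s) were supplied"
        )
```

For the prompt above, with presumably six in-context images plus one query image, `num_images` would be 7.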

Llama2

Loving med-flamingo. However, would also love to use Llama2. Is this likely to come?

Any plans for creating a demo website?

Hello,
We were trying to run the model on an in-house CXR dataset but were not able to get meaningful results.
So we wanted to check whether we are using it the right way. Are there any plans to release a website that demos this model? That would help us verify that we are using the model correctly.

Model training and finetuning scripts

Dear authors, thanks for your wonderful work. Can you provide scripts for training and the detailed data format to prepare a dataset to finetune med-flamingo? This will help us to investigate the result of providing more data to improve flamingo-like models. @mi92

How to calculate the F1 BERT score?

I noticed that BERTScore is used as a metric.
I also want to use it to evaluate the performance of different medical LLMs, but when I used the code from https://github.com/Tiiiger/bert_score to compute it from the ground-truth and predicted answers, I found that the output is a vector like tensor([0.9834, 0.9782, 0.9162, 0.9589, 0.9675, 0.9680, 0.9602, 0.9663, 0.9438, 0.9508]).
How do I get the final metric from the generated tensor?
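The bert_score package returns one value per candidate/reference pair (its `score` function returns precision, recall and F1 tensors, each with one entry per pair), so a single corpus-level number needs an aggregation step. A common convention is to report the mean of the per-example F1 values; whether the paper used exactly this aggregation is not stated here. A dependency-free sketch of that averaging, applied to the tensor printed above:

```python
from statistics import mean
from typing import Sequence


def corpus_bertscore_f1(per_example_f1: Sequence[float]) -> float:
    """Average per-example BERTScore F1 values into one corpus-level score."""
    if not per_example_f1:
        raise ValueError("need at least one score")
    return mean(per_example_f1)


# Values copied from the tensor in the question above.
f1_values = [0.9834, 0.9782, 0.9162, 0.9589, 0.9675,
             0.9680, 0.9602, 0.9663, 0.9438, 0.9508]
print(round(corpus_bertscore_f1(f1_values), 4))  # 0.9593
```

When working with the tensors returned by bert_score directly, the equivalent is `F1.mean().item()`.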

Thank you for open-sourcing this! It's valuable work for the community.

Benchmark dataset availability

In the paper, you created your own USMLE-derived benchmark and also created improved splits of existing VQA benchmarks to address data leakage. Could any of these benchmarks be made available (either publicly or privately on request)? Thanks!

Running demo.py on 2 GPUs

I want to run demo.py on 2 GPUs, so I used the command "accelerate demo.py" in the terminal.
It gives a ChildFailedError.

accelerate env:

Accelerate version: 0.20.0.dev0
Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
Python version: 3.10.12
Numpy version: 1.24.2
PyTorch version (GPU?): 2.0.0+cu117 (True)
PyTorch XPU available: False
System RAM: 31.34 GB
GPU type: NVIDIA GeForce RTX 4090
Accelerate default config:
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: no
use_cpu: False
num_processes: 2
machine_rank: 0
num_machines: 1
gpu_ids: all
rdzv_backend: static
same_network: True
main_training_function: main
downcast_bf16: no
tpu_use_cluster: False
tpu_use_sudo: False
tpu_env: []

RecursionError: maximum recursion depth exceeded

Has anyone had a similar issue?

RecursionError Traceback (most recent call last)
Cell In[8], line 21
16 device = accelerator.device
18 print('Loading model..')
---> 21 model, image_processor, tokenizer = create_model_and_transforms(
22 clip_vision_encoder_path="ViT-L-14",
23 clip_vision_encoder_pretrained="openai",
24 #lang_encoder_path="huggyllama/llama-7b",
25 #tokenizer_path= "huggyllama/llama-7b",
26 lang_encoder_path="decapoda-research/llama-7b-hf",
27 tokenizer_path= "decapoda-research/llama-7b-hf",
28 cross_attn_every_n_layers=4,
29 )

Cell In[2], line 43, in create_model_and_transforms(clip_vision_encoder_path, clip_vision_encoder_pretrained, lang_encoder_path, tokenizer_path, cross_attn_every_n_layers, use_local_files, decoder_layers_attr_name, freeze_lm_embeddings, **flamingo_kwargs)
40 # set the vision encoder to output the visual features
41 vision_encoder.visual.output_tokens = True
---> 43 text_tokenizer = AutoTokenizer.from_pretrained(
44 tokenizer_path,
45 local_files_only=use_local_files,
46 trust_remote_code=True,
47 )
48 # add Flamingo special tokens to the tokenizer
49 text_tokenizer.add_special_tokens(
50 {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
51 )

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py:736, in AutoTokenizer.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
732 if tokenizer_class is None:
733 raise ValueError(
734 f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
735 )
--> 736 return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
738 # Otherwise we have to be creative.
739 # if model is an encoder decoder, the encoder tokenizer class is used by default
740 if isinstance(config, EncoderDecoderConfig):

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1854, in PreTrainedTokenizerBase.from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, *init_inputs, **kwargs)
1851 else:
1852 logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
-> 1854 return cls._from_pretrained(
1855 resolved_vocab_files,
1856 pretrained_model_name_or_path,
1857 init_configuration,
1858 *init_inputs,
1859 token=token,
1860 cache_dir=cache_dir,
1861 local_files_only=local_files_only,
1862 _commit_hash=commit_hash,
1863 _is_local=is_local,
1864 **kwargs,
1865 )

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:2017, in PreTrainedTokenizerBase._from_pretrained(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, token, cache_dir, local_files_only, _commit_hash, _is_local, *init_inputs, **kwargs)
2015 # Instantiate tokenizer.
2016 try:
-> 2017 tokenizer = cls(*init_inputs, **init_kwargs)
2018 except OSError:
2019 raise OSError(
2020 "Unable to load vocabulary from file. "
2021 "Please check that the provided vocabulary is accessible and not corrupted."
2022 )

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/models/llama/tokenization_llama_fast.py:128, in LlamaTokenizerFast.init(self, vocab_file, tokenizer_file, clean_up_tokenization_spaces, unk_token, bos_token, eos_token, add_bos_token, add_eos_token, use_default_system_prompt, **kwargs)
126 self._add_bos_token = add_bos_token
127 self._add_eos_token = add_eos_token
--> 128 self.update_post_processor()
129 self.use_default_system_prompt = use_default_system_prompt
130 self.vocab_file = vocab_file

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/models/llama/tokenization_llama_fast.py:141, in LlamaTokenizerFast.update_post_processor(self)
137 """
138 Updates the underlying post processor with the current bos_token and eos_token.
139 """
140 bos = self.bos_token
--> 141 bos_token_id = self.bos_token_id
143 eos = self.eos_token
144 eos_token_id = self.eos_token_id

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1141, in SpecialTokensMixin.bos_token_id(self)
1139 if self._bos_token is None:
1140 return None
-> 1141 return self.convert_tokens_to_ids(self.bos_token)

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py:277, in PreTrainedTokenizerFast.convert_tokens_to_ids(self, tokens)
274 return None
276 if isinstance(tokens, str):
--> 277 return self._convert_token_to_id_with_added_voc(tokens)
279 return [self._convert_token_to_id_with_added_voc(token) for token in tokens]

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py:284, in PreTrainedTokenizerFast._convert_token_to_id_with_added_voc(self, token)
282 index = self._tokenizer.token_to_id(token)
283 if index is None:
--> 284 return self.unk_token_id
285 return index

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1160, in SpecialTokensMixin.unk_token_id(self)
1158 if self._unk_token is None:
1159 return None
-> 1160 return self.convert_tokens_to_ids(self.unk_token)

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py:277, in PreTrainedTokenizerFast.convert_tokens_to_ids(self, tokens)
274 return None
276 if isinstance(tokens, str):
--> 277 return self._convert_token_to_id_with_added_voc(tokens)
279 return [self._convert_token_to_id_with_added_voc(token) for token in tokens]

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py:284, in PreTrainedTokenizerFast._convert_token_to_id_with_added_voc(self, token)
282 index = self._tokenizer.token_to_id(token)
283 if index is None:
--> 284 return self.unk_token_id
285 return index

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1160, in SpecialTokensMixin.unk_token_id(self)
1158 if self._unk_token is None:
1159 return None
-> 1160 return self.convert_tokens_to_ids(self.unk_token)

[... skipping similar frames: PreTrainedTokenizerFast._convert_token_to_id_with_added_voc at line 284 (986 times), PreTrainedTokenizerFast.convert_tokens_to_ids at line 277 (986 times), SpecialTokensMixin.unk_token_id at line 1160 (985 times)]

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1160, in SpecialTokensMixin.unk_token_id(self)
1158 if self._unk_token is None:
1159 return None
-> 1160 return self.convert_tokens_to_ids(self.unk_token)

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py:277, in PreTrainedTokenizerFast.convert_tokens_to_ids(self, tokens)
274 return None
276 if isinstance(tokens, str):
--> 277 return self._convert_token_to_id_with_added_voc(tokens)
279 return [self._convert_token_to_id_with_added_voc(token) for token in tokens]

File ~/anaconda3/envs/dui/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py:284, in PreTrainedTokenizerFast._convert_token_to_id_with_added_voc(self, token)
282 index = self._tokenizer.token_to_id(token)
283 if index is None:
--> 284 return self.unk_token_id
285 return index

RecursionError: maximum recursion depth exceeded
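Reading the frames above, the loop is: looking up the BOS token id fails, the fast tokenizer falls back to `unk_token_id`, and looking up the unknown token itself also fails, which triggers the same fallback again. That pattern suggests the special tokens configured for this checkpoint are not in its vocabulary, for instance empty strings in tokenizer_config.json. A hedged diagnostic sketch that lists such entries follows; the key names are the standard Hugging Face ones, while the helper empty_special_tokens is hypothetical.

```python
import json
from pathlib import Path
from typing import List

SPECIAL_TOKEN_KEYS = ("bos_token", "eos_token", "unk_token", "pad_token")


def empty_special_tokens(model_dir: str) -> List[str]:
    """Return the special-token keys in tokenizer_config.json whose value is empty."""
    config_path = Path(model_dir) / "tokenizer_config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    empty = []
    for key in SPECIAL_TOKEN_KEYS:
        value = config.get(key)
        # Some configs store a special token as a dict such as {"content": "<s>", ...}.
        if isinstance(value, dict):
            value = value.get("content")
        if value == "":
            empty.append(key)
    return empty
```

If this reports entries, setting them to LLaMA's usual `<s>`, `</s>` and `<unk>`, or pointing both paths at a checkpoint with intact tokenizer files (such as the `huggyllama/llama-7b` path commented out in the snippet above), is worth trying.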

GPU requirements (and other dependencies)

Hi

It would be great to include GPU RAM and storage requirements in the README. The requirements appear to be non-trivial for many.
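As a rough lower bound in the meantime, the memory needed just to hold model weights is the parameter count times the bytes per parameter. The sketch below applies that to a nominal 7B-parameter language model only; it ignores the vision encoder, the added cross-attention layers, activations and generation caches, so real requirements are higher.

```python
# Bytes needed to store one parameter in each precision.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def weight_memory_gb(num_params: float, dtype: str = "fp16") -> float:
    """Lower-bound memory (decimal GB) needed to store the weights alone."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in ("fp32", "fp16"):
    print(dtype, weight_memory_gb(7e9, dtype), "GB")
# fp32 28.0 GB
# fp16 14.0 GB
```

By this estimate, fp32 weights alone (about 28 GB) already exceed a single 24 GB card such as the RTX 4090 mentioned in the multi-GPU issue, whereas fp16 weights (about 14 GB) fit before activations are counted.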

Great work. Thank you for open-sourcing this! It is much appreciated.
