medicinetoken / medical-sam2

Medical SAM 2: Segment Medical Images As Video Via Segment Anything Model 2

License: Apache License 2.0

Python 97.45%, CUDA 2.20%, Shell 0.35%
deep-learning medical medical-imaging segment-anything segment-anything-2 segment-anything-model segmentation

medical-sam2's People

Contributors

jiayuanz3, rabiaedayilmaz, wujunde, yunliqi

medical-sam2's Issues

Design choice of trainable weights

Hi guys,
Thank you for your excellent work.

I am a bit confused about how you chose which parts of the weights to fine-tune.
To clarify, I listed the trainable weights I observed for comparison:

  • In your previous work, Medical-SAM-Adapter
    • The trainable weights are the adapter layers in the encoder and the full mask decoder
  • In this work
    • For 2D training: the entire model is trainable
    • For 3D training: only the mask decoder and the memory-related components are trainable

Could you please explain the rationale behind these decisions?
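For readers comparing these setups, "only some components are trainable" is typically expressed by toggling `requires_grad` per parameter name. A minimal sketch, assuming module-name prefixes modelled on upstream SAM2 (`sam_mask_decoder`, `memory_attention`, `memory_encoder`); these prefixes are not confirmed for this repo's 3D script, and `FakeParam` is a stand-in so the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple

# Hypothetical prefixes modelled on upstream SAM2 module names; not confirmed for this repo.
TRAINABLE_3D_PREFIXES = ("sam_mask_decoder", "memory_attention", "memory_encoder")


@dataclass
class FakeParam:
    """Stand-in for torch.nn.Parameter; only the requires_grad flag matters here."""
    requires_grad: bool = True


def set_trainable(named_params: Iterable[Tuple[str, object]],
                  prefixes: Tuple[str, ...]) -> List[str]:
    """Enable gradients only for parameters whose name starts with one of `prefixes`.

    Returns the names of the parameters left trainable.
    """
    trainable = []
    for name, param in named_params:
        param.requires_grad = name.startswith(prefixes)  # str.startswith accepts a tuple
        if param.requires_grad:
            trainable.append(name)
    return trainable


# Illustrative parameter names only.
params = {
    "image_encoder.trunk.weight": FakeParam(),
    "sam_mask_decoder.output_upscaling.weight": FakeParam(),
    "memory_attention.layers.weight": FakeParam(),
}
print(set_trainable(params.items(), TRAINABLE_3D_PREFIXES))
```

With a real PyTorch model the same call would take `model.named_parameters()`, and the optimizer would then be built from `[p for p in model.parameters() if p.requires_grad]`.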

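Gradient and parameter updates of the loss during 3D training

Hello,

I'd like to ask about the 3D training part of the current code: it computes the average loss over the whole video and then backpropagates it to update the parameters. Would it be possible to compute the loss for each frame, backpropagate it, and update the parameters per frame, so that video_length could be increased from 2 to the length of the whole video?
How should the memory-related code be modified to do this?

Looking forward to your reply, and thank you!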
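Some general autograd reasoning behind the question (not a statement about this repo's code): the gradient is linear, so for a fixed computation graph, backpropagating each frame's loss scaled by 1/T and stepping the optimizer once after the last frame accumulates the same gradient as backpropagating the mean loss over T frames:

```latex
\nabla_\theta \left( \frac{1}{T} \sum_{t=1}^{T} \mathcal{L}_t(\theta) \right)
  = \sum_{t=1}^{T} \nabla_\theta \, \frac{\mathcal{L}_t(\theta)}{T}
```

The catch is the memory: the prediction for frame t depends on memory built from earlier frames, so each term also reaches the parameters through those frames. Calling backward per frame through shared memory nodes would need `retain_graph=True` in PyTorch; detaching the memory between frames (truncated backpropagation) saves GPU memory but drops those cross-frame terms, so the gradient becomes an approximation.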

RuntimeError: CUDA error: no kernel image is available for execution on the device Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I created the environment with the command conda env create -f environment.yml and encountered an error when running the 3D case.

Error:
Traceback (most recent call last):
File "/disk/projects/Medical-SAM2/train_3d.py", line 115, in <module>
main()
File "/disk/projects/Medical-SAM2/train_3d.py", line 98, in main
loss, prompt_loss, non_prompt_loss = function.train_sam(args, net, optimizer1, optimizer2, nice_train_loader, epoch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk/projects/Medical-SAM2/func_3d/function.py", line 114, in train_sam
_, _, _ = net.train_add_new_bbox(
^^^^^^^^^^^^^^^^^^^^^^^
File "/disk/projects/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 439, in train_add_new_bbox
out_frame_idx, out_obj_ids, out_mask_logits = self.train_add_new_points(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk/projects/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 523, in train_add_new_points
current_out, _ = self._run_single_frame_inference(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk/projects/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 1351, in _run_single_frame_inference
pred_masks_gpu = fill_holes_in_mask_scores(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk/projects/Medical-SAM2/sam2_train/utils/misc.py", line 255, in fill_holes_in_mask_scores
is_hole = (labels > 0) & (areas <= max_area)
^^^^^^^^^^
RuntimeError: CUDA error: no kernel image is available for execution on the device
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

System environment
lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 22.04.4 LTS
Release: 22.04
Codename: jammy
nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
conda -V
conda 23.5.0
nvidia-smi
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02 Driver Version: 550.107.02 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Off | 00000000:01:00.0 Off | N/A |
| 35% 55C P0 N/A / 350W | 1MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
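This error usually means that a binary being executed (PyTorch itself, or a separately compiled extension such as the repo's `_C`) contains no kernels for the GPU's compute capability; the RTX 3090 above is compute capability 8.6. CUDA errors are also reported asynchronously, so the line shown in the traceback may not be the kernel that actually failed; setting `CUDA_LAUNCH_BLOCKING=1` makes reporting synchronous. Below is a simplified stdlib sketch of the compatibility rule. It can be fed from `torch.cuda.get_device_capability()` and `torch.cuda.get_arch_list()`, but that only covers PyTorch's own kernels, not the build targets of the `_C` extension.

```python
import re
from typing import Iterable, Tuple

_ARCH_RE = re.compile(r"^(sm|compute)_(\d+)(\d)$")


def kernels_cover_device(capability: Tuple[int, int], arch_list: Iterable[str]) -> bool:
    """Return True if a build compiled for `arch_list` can run on a GPU of `capability`.

    Simplified CUDA compatibility rules:
      * an sm_XY cubin runs on devices with the same major version X and minor >= Y;
      * compute_XY PTX can be JIT-compiled for any device with capability >= (X, Y).
    Entries with suffixes such as "sm_90a" are ignored for simplicity.
    """
    major, minor = capability
    for arch in arch_list:
        m = _ARCH_RE.match(arch)
        if not m:
            continue
        kind, a_major, a_minor = m.group(1), int(m.group(2)), int(m.group(3))
        if kind == "sm" and a_major == major and a_minor <= minor:
            return True
        if kind == "compute" and (a_major, a_minor) <= (major, minor):
            return True
    return False


# With PyTorch installed (both functions exist in torch.cuda):
#   import torch
#   print(kernels_cover_device(torch.cuda.get_device_capability(0), torch.cuda.get_arch_list()))
print(kernels_cover_device((8, 6), ["sm_50", "sm_60", "sm_70", "sm_75"]))  # False
```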

Using pretrained weights on REFUGE dataset does not give good results

Hi, thank you for sharing your code. I downloaded the pretrained weights (MedSAM2_pretrain.pth) from the link you provided, loaded them into the model, and ran the evaluation via train_2d.py on the REFUGE dataset (also downloaded from your link) without any fine-tuning. I was hoping the pretrained model would give decent results so I could confirm that the inference pipeline works. However, I get very low numbers, as shown below:

INFO:root:Total score: 1.5592443943023682, IOU: 0.01751479435546341, DICE: 0.029827685931491824 || @ epoch 0.
Total score: 1.5592443943023682, IOU: 0.01751479435546341, DICE: 0.029827685931491824 || @ epoch 0.

For reference, I slightly modified the code in train_2d.py to run validation directly instead of training the first epoch. The arguments used are as follows:

Namespace(b=1, data_path='./dataset/REFUGE', dataset='REFUGE', distributed='none', encoder='vit_b', exp_name='REFUGE_MedSAM2', gpu=True, gpu_device=0, image_size=1024, lr=0.0001, memory_bank_size=16, multimask_output=1, net='sam2', out_size=1024, path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Samples'}, pretrain='MedSAM2_pretrain.pth', prompt='bbox', prompt_freq=2, sam_ckpt='./checkpoints/sam2_hiera_tiny.pt', sam_config='sam2_hiera_t', train_vis=False, val_freq=1, video_length=2, vis=True, weights=0)

Could you advise what the problem might be?

Thank you
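One quick sanity check when pretrained weights score near zero is whether the checkpoint actually matches the architecture built from `-sam_config` (the run above uses `sam2_hiera_t` with `sam2_hiera_tiny.pt`). The sketch below compares parameter names and shapes. The two dicts are assumed to be built with something like `{k: tuple(v.shape) for k, v in state_dict.items()}` for the model and for the loaded checkpoint, whose weights may be nested under a key depending on how the file was saved. PyTorch's `load_state_dict(..., strict=False)` also returns the missing and unexpected keys.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def compare_state_dicts(model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]) -> dict:
    """Report how well a checkpoint matches a model, given name -> shape maps."""
    model_keys, ckpt_keys = set(model_shapes), set(ckpt_shapes)
    shared = model_keys & ckpt_keys
    return {
        "missing": sorted(model_keys - ckpt_keys),     # in the model, absent from the checkpoint
        "unexpected": sorted(ckpt_keys - model_keys),  # in the checkpoint, unused by the model
        "shape_mismatch": sorted(k for k in shared if model_shapes[k] != ckpt_shapes[k]),
        "matched": sum(1 for k in shared if model_shapes[k] == ckpt_shapes[k]),
    }
```

A large `missing` or `shape_mismatch` count would suggest that the weights were not really loaded into the model being evaluated.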

The results I produced using REFUGE_disc are worse than with REFUGE_cup

To use the disc images in the REFUGE dataset, I modified these codes from:

img_path = os.path.join(subfolder, name + '_cropped.jpg')
multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '_cropped.jpg') for i in range(1, 8)]

to:

img_path = os.path.join(subfolder, name + '.jpg')
multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png') for i in range(1, 8)]

All other code remained the same. I trained for 100 epochs. For the cup in the cropped images, I achieved a Dice of 0.866, an IoU of 0.776, and a score of 0.125, whereas for the disc images I achieved a Dice of 0.825, an IoU of 0.725, and a score of 0.028. I tried sam2_hiera_small and sam2_hiera_large, and the results were similar.

There must be something wrong with my code modifications, or there may be fine-tuning steps that I missed. Could you provide some help on this?
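The loop above (`range(1, 8)`) loads seven rater masks per image; how the repo fuses them into one training target is not shown in this issue. To compare the cup and disc runs on equal footing, here is a hypothetical sketch of majority-vote fusion together with the standard Dice and IoU definitions. The 0.5 vote threshold is an assumption.

```python
from typing import List, Sequence, Tuple

Mask = List[List[int]]


def majority_vote(rater_masks: Sequence[Mask], threshold: float = 0.5) -> Mask:
    """Fuse binary rater masks: foreground where the fraction of raters marking a pixel exceeds `threshold`."""
    n = len(rater_masks)
    h, w = len(rater_masks[0]), len(rater_masks[0][0])
    return [
        [int(sum(m[y][x] for m in rater_masks) / n > threshold) for x in range(w)]
        for y in range(h)
    ]


def dice_iou(pred: Mask, target: Mask) -> Tuple[float, float]:
    """Dice = 2|P∩T| / (|P|+|T|), IoU = |P∩T| / |P∪T|; both 1.0 when both masks are empty."""
    inter = sum(p & t for pr, tr in zip(pred, target) for p, t in zip(pr, tr))
    p_sum = sum(map(sum, pred))
    t_sum = sum(map(sum, target))
    union = p_sum + t_sum - inter
    if union == 0:
        return 1.0, 1.0
    return 2 * inter / (p_sum + t_sum), inter / union
```

For a single mask pair, Dice = 2·IoU / (1 + IoU); scores averaged over images will not follow this relation exactly.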

Inference Example?

Hello,

Is there an inference example (notebook or Python script) available?

Thanks.

About video_length?

Thank you very much for your work! I would like to ask about the default settings in the train_3d cfg during the training phase. Is the default setting for video_length 2? Does it train only two adjacent images at a time?
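The Namespace dumps elsewhere on this page do show `video_length=2`. For context on what that implies, here is a hypothetical sketch of drawing a clip of consecutive slices from a volume. The uniformly random start index is an assumption about how a loader might do it, not a reading of the repo's BTCV loader.

```python
import random
from typing import List, Optional


def sample_window(num_slices: int, video_length: int,
                  rng: Optional[random.Random] = None) -> List[int]:
    """Pick `video_length` consecutive slice indices from a volume with `num_slices` slices.

    If the volume is not longer than the window, all slices are returned.
    """
    rng = rng or random.Random()
    if num_slices <= video_length:
        return list(range(num_slices))
    start = rng.randint(0, num_slices - video_length)  # randint's upper bound is inclusive
    return list(range(start, start + video_length))
```

With `video_length=2`, each call yields two adjacent slices, e.g. `[17, 18]`.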

Is there any available code for inference?

Hi there,

I'm currently working on a project to implement CT and MRI segmentation using this model. However, I've looked through the documentation and the examples provided in the repository, but I couldn't find a clear example of how to set up inference, especially for video segmentation.

Could you please let me know if there is any existing code or example notebooks for performing inference with the SAM2 model? Specifically, I'm looking for guidance on how to properly initialize the model components and run inference on video frames.

Thanks in advance for your help!

Windows setup

Does it run well on Windows? How should the specific commands be configured?

Cannot run 3d training

I encountered an error while trying to run training on the BTCV dataset you provided. The error message was: "ImportError: cannot import name '_C' from 'sam2_train'." The full error message is:

Traceback (most recent call last):
File "Medical-SAM2/train_3d.py", line 112, in <module>
main()
File "Medical-SAM2/train_3d.py", line 95, in main
loss, prompt_loss, non_prompt_loss = function.train_sam(args, net, optimizer1, optimizer2, nice_train_loader, epoch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "Medical-SAM2/func_3d/function.py", line 114, in train_sam
_, _, _ = net.train_add_new_bbox(
^^^^^^^^^^^^^^^^^^^^^^^
File "Medical-SAM2/sam2_train/sam2_video_predictor.py", line 439, in train_add_new_bbox
out_frame_idx, out_obj_ids, out_mask_logits = self.train_add_new_points(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "Medical-SAM2/sam2_train/sam2_video_predictor.py", line 523, in train_add_new_points
current_out, _ = self._run_single_frame_inference(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "Medical-SAM2/sam2_train/sam2_video_predictor.py", line 1351, in _run_single_frame_inference
pred_masks_gpu = fill_holes_in_mask_scores(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "Medical-SAM2/sam2_train/utils/misc.py", line 254, in fill_holes_in_mask_scores
labels, areas = get_connected_components(mask <= 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "Medical-SAM2/sam2_train/utils/misc.py", line 61, in get_connected_components
from sam2_train import _C
ImportError: cannot import name '_C' from 'sam2_train' (Medical-SAM2/sam2_train/__init__.py)

I’m having trouble understanding this error and would appreciate any help you can offer. Thank you so much!
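From the traceback, `_C` is only needed by `get_connected_components`, which `fill_holes_in_mask_scores` uses to post-process predicted masks. The import most likely fails because the compiled `_C` extension is not present (or not importable) in the `sam2_train` package. To illustrate what that step computes, and as a possible (slow) fallback if the extension cannot be built, here is a pure-Python sketch. Background pixels (logit <= 0) are grouped into connected components, and components no larger than `max_area` are overwritten with a small positive value. The 8-connectivity and the fill value of 0.1 are assumptions modelled on upstream SAM2, not checked against this repo's kernel.

```python
from collections import deque
from typing import List

Grid = List[List[float]]


def fill_small_holes(logits: Grid, max_area: int, fill_value: float = 0.1) -> Grid:
    """Overwrite small background regions (logit <= 0) with `fill_value`.

    Background pixels are grouped into 8-connected components; components with
    at most `max_area` pixels are treated as holes.
    """
    h, w = len(logits), len(logits[0])
    out = [row[:] for row in logits]
    seen = [[False] * w for _ in range(h)]
    for sy in range(h):
        for sx in range(w):
            if seen[sy][sx] or logits[sy][sx] > 0:
                continue
            # Breadth-first search over one background component.
            comp, queue = [], deque([(sy, sx)])
            seen[sy][sx] = True
            while queue:
                y, x = queue.popleft()
                comp.append((y, x))
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w and not seen[ny][nx] and logits[ny][nx] <= 0:
                            seen[ny][nx] = True
                            queue.append((ny, nx))
            if len(comp) <= max_area:
                for y, x in comp:
                    out[y][x] = fill_value
    return out
```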

Request the MedSAM 2 weights

Hi, could you release the trained weights of MedSAM 2? I noticed that you used SAM2-small. Have you tried the base and large versions? Can you release their weights?

fine-tune

How can I fine-tune with my own dataset?

training error

Hi, I encountered the following error when running train_3d.py. Do you have any suggestions?

File "/data/humanBodyProject/repository_YC/Medical-SAM2/sam2_train/utils/misc.py", line 61, in get_connected_components
from sam2_train import _C
ImportError: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /data/humanBodyProject/repository_YC/Medical-SAM2/sam2_train/_C.so)
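This message means the shipped `_C.so` requires symbols from glibc 2.32 or newer, while the system's `libc.so.6` is older, so the extension cannot load. Rebuilding `_C` on the target machine or running on a newer distribution are the typical remedies. `ldd --version` prints the system's glibc version. The same check from Python, using only `platform.libc_ver()`:

```python
import platform
from typing import Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    return tuple(int(part) for part in v.split("."))


def glibc_at_least(required: str, current: str = "") -> bool:
    """Return True if glibc version `current` (default: this system's) is >= `required`."""
    if not current:
        lib, current = platform.libc_ver()
        if lib != "glibc" or not current:
            return False  # not glibc, or the version could not be detected
    return parse_version(current) >= parse_version(required)


print(glibc_at_least("2.32"))  # False on a system like the one in the error above
```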

Checkpoints issue

Thank you for your work on the model!

When I run the checkpoint download script with bash, I get the following error:

$ cd checkpoints
bash download_ckpts.sh
Downloading sam2_hiera_tiny.pt checkpoint...
download_ckpts.sh: line 20: wget: command not found
Failed to download checkpoint from https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt

And this is the message when accessing the URL directly:

This XML file does not appear to have any style information associated with it. The document tree is shown below.

AccessDenied
Access Denied
CHG0JWSD8DDSWEEN
+Zia+nffTq4xbvpcCiWxhQc/R62zgPYXDa2qJcigatCMbCi3h/Wxo+6y0h89/T/OlEqQ/Fz33v7yR9ILjkhGxg==

Are the checkpoints required to run train_3d.py with a decent Dice score?
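The script stops at `wget: command not found`, before any download is attempted. Installing wget, or fetching with `curl -L -O <url>`, avoids that. Where neither is available, here is a stdlib-only sketch that uses the base URL from the error message above. This does not address the Access Denied response if the server itself rejects the request (`urlretrieve` would then raise an HTTP error).

```python
import os
import urllib.request
from typing import Iterable, List

# Base URL as quoted in the error message above.
BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"


def checkpoint_urls(names: Iterable[str], base: str = BASE_URL) -> List[str]:
    return [base + name for name in names]


def download(url: str, dest_dir: str = ".") -> str:
    """Fetch `url` into `dest_dir` using only the standard library (no wget needed)."""
    dest = os.path.join(dest_dir, url.rsplit("/", 1)[-1])
    urllib.request.urlretrieve(url, dest)
    return dest


# Usage (network access required):
#   for url in checkpoint_urls(["sam2_hiera_tiny.pt", "sam2_hiera_small.pt"]):
#       download(url, "checkpoints")
```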

installing requirements in kaggle

Hello,
I'm trying to install the requirements on Kaggle, but I'm encountering several issues.

Conda Installation: Installing the requirements using Conda takes an extremely long time, and the progress seems to be stuck with no visible advancement.
Has anyone successfully installed the requirements on Kaggle?
This is my code for the conda environment:

!conda create -n medsam2 python=3.12.4 -y > /dev/null
!source /opt/conda/bin/activate medsam2
!sudo rm /opt/conda/bin/python > /dev/null
!sudo ln -s /opt/conda/envs/medsam2/bin/python3 /opt/conda/bin/python > /dev/null
!sudo rm /opt/conda/bin/python3 > /dev/null
!sudo ln -sf /opt/conda/envs/medsam2/bin/python3 /opt/conda/bin/python3 > /dev/null

I also tried installing them directly with pip (both from a requirements.txt I created from the yml file and by installing the packages directly), but the default Python version on Kaggle is 3.10. I tried changing it, but even with Python 3.12.4 I encountered errors.
Here are some of the errors I get:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf 24.8.2 requires cubinlinker, which is not installed.
cudf 24.8.2 requires cupy-cuda11x>=12.0.0, which is not installed.
cudf 24.8.2 requires ptxcompiler, which is not installed.
cuml 24.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
dask-cudf 24.8.2 requires cupy-cuda11x>=12.0.0, which is not installed.
ucxx 0.39.1 requires libucx>=1.15.0, which is not installed.
accelerate 0.33.0 requires numpy<2.0.0,>=1.17, but you have numpy 2.0.1 which is incompatible.
albucore 0.0.13 requires numpy<2,>=1.24.4, but you have numpy 2.0.1 which is incompatible.
apache-beam 2.46.0 requires cloudpickle~=2.2.1, but you have cloudpickle 3.0.0 which is incompatible.
apache-beam 2.46.0 requires dill<0.3.2,>=0.3.1.1, but you have dill 0.3.8 which is incompatible.
apache-beam 2.46.0 requires numpy<1.25.0,>=1.14.3, but you have numpy 2.0.1 which is incompatible.
apache-beam 2.46.0 requires protobuf<4,>3.12.2, but you have protobuf 4.25.4 which is incompatible.
apache-beam 2.46.0 requires pyarrow<10.0.0,>=3.0.0, but you have pyarrow 16.1.0 which is incompatible.
beatrix-jupyterlab 2024.66.154055 requires jupyterlab~=3.6.0, but you have jupyterlab 4.2.4 which is incompatible.
bigframes 0.22.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.10.0, but you have google-cloud-bigquery 2.34.4 which is incompatible.
bigframes 0.22.0 requires google-cloud-storage>=2.0.0, but you have google-cloud-storage 1.44.0 which is incompatible.
bigframes 0.22.0 requires pandas<2.1.4,>=1.5.0, but you have pandas 2.2.2 which is incompatible.
cudf 24.8.2 requires cuda-python<12.0a0,>=11.7.1, but you have cuda-python 12.6.0 which is incompatible.
cudf 24.8.2 requires numpy<2.0a0,>=1.23, but you have numpy 2.0.1 which is incompatible.
dask-cuda 24.8.2 requires numpy<2.0a0,>=1.23, but you have numpy 2.0.1 which is incompatible.
dask-cudf 24.8.2 requires numpy<2.0a0,>=1.23, but you have numpy 2.0.1 which is incompatible.
dataproc-jupyter-plugin 0.1.79 requires pydantic~=1.10.0, but you have pydantic 2.8.2 which is incompatible.
distributed 2024.7.1 requires dask==2024.7.1, but you have dask 2024.8.1 which is incompatible.
fitter 1.7.1 requires numpy<2.0.0,>=1.20.0, but you have numpy 2.0.1 which is incompatible.
gensim 4.3.3 requires numpy<2.0,>=1.18.5, but you have numpy 2.0.1 which is incompatible.
gensim 4.3.3 requires scipy<1.14.0,>=1.7.0, but you have scipy 1.14.0 which is incompatible.
google-cloud-aiplatform 0.6.0a1 requires google-api-core[grpc]<2.0.0dev,>=1.22.2, but you have google-api-core 2.11.1 which is incompatible.
google-cloud-automl 1.0.1 requires google-api-core[grpc]<2.0.0dev,>=1.14.0, but you have google-api-core 2.11.1 which is incompatible.
google-cloud-bigquery 2.34.4 requires packaging<22.0dev,>=14.3, but you have packaging 24.1 which is incompatible.
google-cloud-bigquery 2.34.4 requires protobuf<4.0.0dev,>=3.12.0, but you have protobuf 4.25.4 which is incompatible.
google-cloud-bigtable 1.7.3 requires protobuf<4.0.0dev, but you have protobuf 4.25.4 which is incompatible.
google-cloud-vision 2.8.0 requires protobuf<4.0.0dev,>=3.19.0, but you have protobuf 4.25.4 which is incompatible.
ibis-framework 7.1.0 requires numpy<2,>=1, but you have numpy 2.0.1 which is incompatible.
ibis-framework 7.1.0 requires pyarrow<15,>=2, but you have pyarrow 16.1.0 which is incompatible.
kfp 2.5.0 requires google-cloud-storage<3,>=2.2.1, but you have google-cloud-storage 1.44.0 which is incompatible.
kfp 2.5.0 requires protobuf<4,>=3.13.0, but you have protobuf 4.25.4 which is incompatible.
kfp-pipeline-spec 0.2.2 requires protobuf<4,>=3.13.0, but you have protobuf 4.25.4 which is incompatible.
libpysal 4.9.2 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.
momepy 0.7.2 requires shapely>=2, but you have shapely 1.8.5.post1 which is incompatible.
numba 0.58.1 requires numpy<1.27,>=1.22, but you have numpy 2.0.1 which is incompatible.
opentelemetry-exporter-otlp 1.25.0 requires opentelemetry-exporter-otlp-proto-http==1.25.0, but you have opentelemetry-exporter-otlp-proto-http 1.21.0 which is incompatible. .....

I would be really grateful if anyone could help me with this.
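Most of these lines share one shape ("X requires Y, but you have Z which is incompatible"), and many point at the same installed package (numpy 2.0.1, protobuf 4.25.4). A small stdlib sketch that groups them by the installed package, to show which single pin causes most of the conflicts. The regex assumes pip's wording exactly as printed above.

```python
import re
from collections import defaultdict
from typing import Dict, Iterable, List

_CONFLICT = re.compile(
    r"^(?P<pkg>\S+) (?P<ver>\S+) requires (?P<req>.+?), "
    r"but you have (?P<have>\S+) (?P<have_ver>\S+) which is incompatible\."
)


def group_conflicts(lines: Iterable[str]) -> Dict[str, List[str]]:
    """Map each installed-but-conflicting package to the packages that complain about it."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for line in lines:
        m = _CONFLICT.match(line.strip())
        if m:  # lines such as "... which is not installed." are skipped
            key = f"{m['have']} {m['have_ver']}"
            groups[key].append(f"{m['pkg']} {m['ver']} wants {m['req']}")
    return dict(groups)
```

Sorting the result by `len` of each list shows the most contested pins first. These are warnings about other preinstalled packages, so they are not necessarily fatal for this project.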

train_3d.py

I got an error when executing the train_3d.py script.

The command and error output are as follows:

(medsam2) (base) root@7a3a88562600:/workspace/vs-code/Medical-SAM2# python train_3d.py -net sam2 -exp_name BTCV_MedSAM2 -sam_ckpt ./checkpoints/sam2_hiera_small.pt -sam_config sam2_hiera_s -image_size 1024 -val_freq 1 -prompt bbox -prompt_freq 2 -dataset btcv -data_path ./data/btcv
/workspace/vs-code/Medical-SAM2/func_3d/function.py:41: FutureWarning: torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead.
scaler = torch.cuda.amp.GradScaler()
INFO:root:Namespace(net='sam2', encoder='vit_b', exp_name='BTCV_MedSAM2', vis=False, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='btcv', sam_ckpt='./checkpoints/sam2_hiera_small.pt', sam_config='sam2_hiera_s', video_length=2, b=1, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='./data/btcv', path_helper={'prefix': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59', 'ckpt_path': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59/Model', 'log_path': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59/Log', 'sample_path': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59/Samples'})
Namespace(net='sam2', encoder='vit_b', exp_name='BTCV_MedSAM2', vis=False, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='btcv', sam_ckpt='./checkpoints/sam2_hiera_small.pt', sam_config='sam2_hiera_s', video_length=2, b=1, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='./data/btcv', path_helper={'prefix': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59', 'ckpt_path': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59/Model', 'log_path': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59/Log', 'sample_path': 'logs/BTCV_MedSAM2_2024_08_26_12_45_59/Samples'})
Epoch 0: 0%| | 0/24 [00:09<?, ?img/s]
Traceback (most recent call last):
File "/workspace/vs-code/Medical-SAM2/train_3d.py", line 114, in <module>
main()
File "/workspace/vs-code/Medical-SAM2/train_3d.py", line 97, in main
loss, prompt_loss, non_prompt_loss = function.train_sam(args, net, optimizer1, optimizer2, nice_train_loader, epoch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vs-code/Medical-SAM2/func_3d/function.py", line 114, in train_sam
_, _, _ = net.train_add_new_bbox(
^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vs-code/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 439, in train_add_new_bbox
out_frame_idx, out_obj_ids, out_mask_logits = self.train_add_new_points(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vs-code/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 523, in train_add_new_points
current_out, _ = self._run_single_frame_inference(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vs-code/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 1351, in _run_single_frame_inference
pred_masks_gpu = fill_holes_in_mask_scores(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/vs-code/Medical-SAM2/sam2_train/utils/misc.py", line 255, in fill_holes_in_mask_scores
is_hole = (labels > 0) & (areas <= max_area)
^^^^^^^^^^
RuntimeError: CUDA error: no kernel image is available for execution on the device
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Here's my CUDA and torch info:

[image: CUDA and torch info]

Implementation of weighted pick up strategy and 'confidence-first'

While going through the code, I have two questions that I would really appreciate if you could help explain:

  1. In the validation_sam() method in function.py, I could not find the code that explicitly implements the weighted pick-up strategy. Could you point out which part of the code takes care of that?

  2. Also, I do not see how the 'confidence-first' memory is implemented in the code, although I do see the memory bank being filled with diverse input features through similarity calculation.

Thank you for your help in advance.
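To make the two ideas in the question concrete, here is a hypothetical sketch of a bounded memory bank that combines a confidence-first rule with similarity-based diversity. When the bank is full, the stored entry most similar to the new one (the most redundant slot) is replaced, but only if the new entry is more confident. This is an illustration only, not the repo's implementation and not necessarily the paper's exact rule.

```python
import math
from dataclasses import dataclass, field
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


@dataclass
class ConfidenceMemoryBank:
    """Hypothetical bounded memory: confident entries win, near-duplicates get replaced."""
    capacity: int
    feats: List[List[float]] = field(default_factory=list)
    confs: List[float] = field(default_factory=list)

    def add(self, feat: List[float], conf: float) -> bool:
        """Try to store (feat, conf); return True if it was stored."""
        if len(self.feats) < self.capacity:
            self.feats.append(feat)
            self.confs.append(conf)
            return True
        # The stored entry most similar to the new one is the most redundant slot.
        sims = [cosine(feat, f) for f in self.feats]
        idx = max(range(len(sims)), key=sims.__getitem__)
        if conf > self.confs[idx]:
            self.feats[idx], self.confs[idx] = feat, conf
            return True
        return False
```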

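Is prediction without labels supported?

In train.py, training the model relies on labels generated from a click or a bbox, which are then used to generate the masks for training.
My task is to segment every structure in brain MRI scans, e.g. the ventricles, the brainstem, and the cerebrospinal fluid.
In medical-sam2, when I run prediction on arbitrary patient data, do I also have to provide a click or a bbox to get a result?
Can segmentation be done fully automatically without labels? If not, could I take the mask produced from a single bbox on one patient, copy it into the other patients' mask files, and run prediction on those patients?
I hope to hear back from the authors. Thanks in advance!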
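Box prompts are commonly derived from a reference mask by taking the extent of its foreground pixels. Here is a minimal sketch of that derivation with an optional margin. The `(x_min, y_min, x_max, y_max)` order is an assumption; check which order the repo's prompt code expects.

```python
from typing import List, Optional, Tuple

Box = Tuple[int, int, int, int]


def mask_to_bbox(mask: List[List[int]], margin: int = 0) -> Optional[Box]:
    """Return (x_min, y_min, x_max, y_max) of nonzero pixels, padded by `margin` and clipped.

    Returns None for an empty mask, since there is nothing to prompt with.
    """
    h, w = len(mask), len(mask[0])
    ys = [y for y, row in enumerate(mask) if any(row)]
    xs = [x for x in range(w) if any(row[x] for row in mask)]
    if not ys:
        return None
    return (max(xs[0] - margin, 0), max(ys[0] - margin, 0),
            min(xs[-1] + margin, w - 1), min(ys[-1] + margin, h - 1))
```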

Problems in click training.

Dear Authors,

Thank you for your timely work on fine-tuning SAM2 for medical imaging. I am trying to run your framework, but I have run into some bugs that are taking a long time to debug 😢
I hope you can complete your code soon so that other work can make a fair comparison.

Thank you!

MaskDecoder resolution

Does the current mask decoder only have a learnable output resolution of 256x256, with the result then simply interpolated?

Won't going from 256x256 to 1024x1024 without any intermediate learnable step or layer produce a lot of interpolation artifacts?
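To show what a purely non-learned upsampling step does, here is bilinear interpolation with half-pixel centres, the convention of PyTorch's `F.interpolate(..., mode="bilinear", align_corners=False)`. Whether the repo upsamples its decoder logits exactly this way is an assumption here. Each output value is a weighted average of at most four input logits, so no detail finer than the input grid can appear, and boundaries are smoothed.

```python
from typing import List

Grid = List[List[float]]


def bilinear_upsample(img: Grid, out_h: int, out_w: int) -> Grid:
    """Bilinear resize with half-pixel centres (the align_corners=False convention)."""
    in_h, in_w = len(img), len(img[0])

    def src(o: int, n_in: int, n_out: int):
        # Map an output index to a continuous input coordinate, clamped to the valid range.
        s = (o + 0.5) * n_in / n_out - 0.5
        s = min(max(s, 0.0), n_in - 1)
        i0 = int(s)
        return i0, min(i0 + 1, n_in - 1), s - i0

    out = []
    for oy in range(out_h):
        y0, y1, wy = src(oy, in_h, out_h)
        row = []
        for ox in range(out_w):
            x0, x1, wx = src(ox, in_w, out_w)
            top = img[y0][x0] * (1 - wx) + img[y0][x1] * wx
            bottom = img[y1][x0] * (1 - wx) + img[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out


# A 1x2 logit map upsampled 4x along x: values are blended, never sharper than the input grid.
print(bilinear_upsample([[0.0, 4.0]], 1, 4))  # [[0.0, 1.0, 3.0, 4.0]]
```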

Unable to build environment using environment.yml file

I am unable to build the environment using the provided environment.yml file.
Operating System: Windows 11
Conda Version: 24.5.0
Python Version: 3.12.4

After running the environment creation command (conda env create -f environment.yml) in the Medical-SAM2 directory, the process fails with the following error message:
Channels:
  • conda-forge
  • defaults
Platform: win-64
Collecting package metadata (repodata.json): done
Solving environment: failed
Channels:
  • conda-forge
  • defaults
Platform: win-64
Collecting package metadata (repodata.json): done
Solving environment: failed

LibMambaUnsatisfiableError: Encountered problems while solving:

  • nothing provides requested zlib ==1.2.13 h5eee18b_1
  • nothing provides requested zeromq ==4.3.5 h6a678d5_0
  • nothing provides requested xz ==5.4.6 h5eee18b_1
  • nothing provides requested wheel ==0.43.0 py312h06a4308_0
  • nothing provides requested tornado ==6.4.1 py312h5eee18b_0
  • nothing provides requested tk ==8.6.14 h39e8969_0
  • nothing provides requested sqlite ==3.45.3 h5eee18b_0
  • nothing provides requested setuptools ==69.5.1 py312h06a4308_0
  • nothing provides requested readline ==8.2 h5eee18b_0
  • nothing provides requested pyzmq ==25.1.2 py312h6a678d5_0
  • nothing provides requested python ==3.12.4 h5148396_1
  • nothing provides requested pip ==24.0 py312h06a4308_0
  • nothing provides requested openssl ==3.0.14 h5eee18b_0
  • nothing provides requested ncurses ==6.4 h6a678d5_0
  • nothing provides requested libuuid ==1.41.5 h5eee18b_0
  • nothing provides requested libstdcxx-ng ==11.2.0 h1234567_1
  • nothing provides requested libsodium ==1.0.18 h36c2ea0_1
  • nothing provides requested libgomp ==11.2.0 h1234567_1
  • nothing provides requested libgcc-ng ==11.2.0 h1234567_1
  • nothing provides requested libffi ==3.4.4 h6a678d5_1
  • nothing provides requested ld_impl_linux-64 ==2.38 h1181459_1
  • nothing provides requested jupyter_core ==5.7.2 py312h06a4308_0
  • nothing provides __unix needed by ipython-8.26.0-pyh707e725_0
  • nothing provides __linux needed by ipykernel-6.29.5-pyh3099207_0
  • nothing provides requested expat ==2.6.2 h6a678d5_0
  • nothing provides requested debugpy ==1.6.7 py312h6a678d5_0
  • nothing provides requested ca-certificates ==2024.7.4 hbcca054_0
  • nothing provides requested bzip2 ==1.0.8 h5eee18b_6
  • nothing provides requested _openmp_mutex ==5.1 1_gnu

Could not solve for environment specs
The following packages are incompatible
├─ _openmp_mutex ==5.1 1_gnu does not exist (perhaps a typo or a missing channel);
├─ bzip2 ==1.0.8 h5eee18b_6 does not exist (perhaps a typo or a missing channel);
├─ ca-certificates ==2024.7.4 hbcca054_0 does not exist (perhaps a typo or a missing channel);
├─ debugpy ==1.6.7 py312h6a678d5_0 does not exist (perhaps a typo or a missing channel);
├─ expat ==2.6.2 h6a678d5_0 does not exist (perhaps a typo or a missing channel);
├─ ipykernel ==6.29.5 pyh3099207_0 is not installable because it requires
│ └─ __linux, which is missing on the system;
├─ ipython ==8.26.0 pyh707e725_0 is not installable because it requires
│ └─ __unix, which is missing on the system;
├─ jupyter_core ==5.7.2 py312h06a4308_0 does not exist (perhaps a typo or a missing channel);
├─ ld_impl_linux-64 ==2.38 h1181459_1 does not exist (perhaps a typo or a missing channel);
├─ libffi ==3.4.4 h6a678d5_1 does not exist (perhaps a typo or a missing channel);
├─ libgcc-ng ==11.2.0 h1234567_1 does not exist (perhaps a typo or a missing channel);
├─ libgomp ==11.2.0 h1234567_1 does not exist (perhaps a typo or a missing channel);
├─ libsodium ==1.0.18 h36c2ea0_1 does not exist (perhaps a typo or a missing channel);
├─ libstdcxx-ng ==11.2.0 h1234567_1 does not exist (perhaps a typo or a missing channel);
├─ libuuid ==1.41.5 h5eee18b_0 does not exist (perhaps a typo or a missing channel);
├─ ncurses ==6.4 h6a678d5_0 does not exist (perhaps a typo or a missing channel);
├─ openssl ==3.0.14 h5eee18b_0 does not exist (perhaps a typo or a missing channel);
├─ readline ==8.2 h5eee18b_0 does not exist (perhaps a typo or a missing channel);
├─ setuptools ==69.5.1 py312h06a4308_0 does not exist (perhaps a typo or a missing channel);
├─ sqlite ==3.45.3 h5eee18b_0 does not exist (perhaps a typo or a missing channel);
├─ tk ==8.6.14 h39e8969_0 does not exist (perhaps a typo or a missing channel);
├─ tornado ==6.4.1 py312h5eee18b_0 does not exist (perhaps a typo or a missing channel);
├─ wheel ==0.43.0 py312h06a4308_0 does not exist (perhaps a typo or a missing channel);
├─ xz ==5.4.6 h5eee18b_1 does not exist (perhaps a typo or a missing channel);
├─ zeromq ==4.3.5 h6a678d5_0 does not exist (perhaps a typo or a missing channel);
└─ zlib ==1.2.13 h5eee18b_1 does not exist (perhaps a typo or a missing channel).
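Every unsatisfiable spec above carries a Linux build string (e.g. `h5eee18b_1`), and several packages are Linux-only (`ld_impl_linux-64`, `libgcc-ng`, and `ipykernel` needing `__linux`). This suggests that environment.yml was exported on Linux with build strings included. Here is a sketch that relaxes such pins. It assumes each conda dependency is written as `name=version=build`, and the set of packages to drop on Windows is left to the user.

```python
import re
from typing import AbstractSet, Iterable, List

# Fully pinned conda spec: name=version=build, e.g. "zlib=1.2.13=h5eee18b_1".
_PINNED = re.compile(r"^(?P<name>[A-Za-z0-9_.\-]+)=(?P<version>[^=\s]+)=(?P<build>[^=\s]+)$")


def relax_specs(specs: Iterable[str], drop: AbstractSet[str] = frozenset()) -> List[str]:
    """Strip build strings from name=version=build specs and skip packages listed in `drop`.

    Other specs (e.g. "numpy", or pip-style "torch==2.4.0") are kept unchanged.
    """
    out = []
    for raw in specs:
        spec = raw.strip()
        m = _PINNED.match(spec)
        # Package name: from the regex if pinned, else everything before the first operator.
        name = m["name"] if m else re.split(r"[=<>!\s]", spec, maxsplit=1)[0]
        if name in drop:
            continue
        out.append(f"{m['name']}={m['version']}" if m else spec)
    return out
```

This would be applied to the conda `dependencies:` entries, not to the nested `pip:` list. At export time, `conda env export --no-builds` produces build-free pins directly.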

Is there code for segmentation visualization?

Hi,

I was looking through the README and noticed the impressive video segmentation visualization result in 3D Abdomen Segmentation Visualisation. Now I can use your code for inference, but I would like to visualize the segmentation result like the one shown in README.

I was wondering if there is any existing code in this project for visualizing segmentation results like this? If so, could you please point me to the relevant files or functions?

If not, could you provide some details on how this visualization was generated (e.g., tools or libraries used, how to overlay different segmentation regions on the original image, etc.)?

Thank you for your help!
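How the README figure was produced is not stated here. As a minimal illustration of the overlay step being asked about (color each label and alpha-blend it over the grayscale slice), here is a stdlib sketch. In practice NumPy or matplotlib would do the same thing vectorized, and the palette is arbitrary.

```python
from typing import Dict, List, Tuple

RGB = Tuple[int, int, int]


def overlay(gray: List[List[int]], labels: List[List[int]],
            palette: Dict[int, RGB], alpha: float = 0.5) -> List[List[RGB]]:
    """Blend one colour per label over a grayscale image; label 0 (background) is left as-is."""
    out = []
    for g_row, l_row in zip(gray, labels):
        row = []
        for g, lab in zip(g_row, l_row):
            if lab == 0 or lab not in palette:
                row.append((g, g, g))
            else:
                row.append(tuple(round((1 - alpha) * g + alpha * c) for c in palette[lab]))
        out.append(row)
    return out
```

Applied slice by slice, the resulting RGB frames can then be written out as images or a video.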

Video segmentation?

Can MedSAM2 automatically segment organs in endoscopic surgery videos?

one-shot inference

Hi, congratulations on this amazing project.

I was wondering what the correct way is to perform one-shot segmentation as described in the paper. Are there any examples?

predict

How do you make a direct prediction?

Support for One-Prompt segmentation

In function.py, I noticed that the validation_sam function is used for performing inference. However, I was wondering whether there is a way to segment a new image without prompting it directly, using a one-shot example instead.

Specifically, I am looking for a One-Prompt approach as described in the paper, where a query image is segmented based on a template image and its corresponding prompt. Could you provide guidance on how to implement this, or is there a built-in function that supports this use case?

inference error

Hello,

I encountered some errors while trying to run inference with the following command:

!python bctv_test.py -net sam2 -exp_name BTCV_MedSAM2 -sam_ckpt {base_path}/Medical-SAM2/sam2_hiera_small.pt -sam_config sam2_hiera_s.yaml -image_size 1024 -val_freq 1 -prompt bbox -prompt_freq 2 -dataset btcv -data_path ./data/btcv

The specific errors I encountered are as follows:
/kaggle/working/Medical-SAM2/sam2_train/modeling/sam/transformer.py:22: UserWarning: Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
Validation round: 0%| | 0/6 [00:00<?, ?batch/s]
/opt/conda/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
Traceback (most recent call last):
File "/kaggle/working/Medical-SAM2/func_3d/function.py", line 263, in validation_sam
bbox = bbox_dict[id][ann_obj_id]
KeyError: 2.0

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/kaggle/working/Medical-SAM2/bctv_test.py", line 52, in <module>
main(args)
File "/kaggle/working/Medical-SAM2/bctv_test.py", line 38, in main
validation_loss, validation_metrics = validation_sam(args, nice_test_loader, 0, net)
File "/kaggle/working/Medical-SAM2/func_3d/function.py", line 273, in validation_sam
_, _, _ = net.train_add_new_mask(
File "/kaggle/working/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 692, in train_add_new_mask
current_out, _ = self._run_single_frame_inference(
File "/kaggle/working/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 1351, in _run_single_frame_inference
pred_masks_gpu = fill_holes_in_mask_scores(
File "/kaggle/working/Medical-SAM2/sam2_train/utils/misc.py", line 254, in fill_holes_in_mask_scores
labels, areas = get_connected_components(mask <= 0)
File "/kaggle/working/Medical-SAM2/sam2_train/utils/misc.py", line 61, in get_connected_components
from sam2_train import _C
ImportError: /kaggle/working/Medical-SAM2/sam2_train/_C.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs

I noticed that similar _C.so issues have been mentioned in the SAM2 repository. As I understand it, one possible solution involves recompiling the C++ extensions using a setup.py file to ensure compatibility with the current version of PyTorch. However, I couldn't find a setup.py file in your project.

Could you please help me with this?
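Reading this traceback: the `KeyError: 2.0` is raised inside a `try` whose handler calls `net.train_add_new_mask`, and it is that handler that hits the `_C.so` undefined-symbol error ("During handling of the above exception, another exception occurred"). The KeyError itself only says that object 2 has no box on that frame. In Python `2.0 == 2` and `hash(2.0) == hash(2)`, so a float id still finds an int key. Here is a defensive lookup, with a hypothetical helper name and a `{frame: {obj_id: box}}` layout inferred from the `bbox_dict[id][ann_obj_id]` expression.

```python
from typing import Dict, List, Optional

BBoxDict = Dict[int, Dict[int, List[float]]]


def lookup_bbox(bbox_dict: BBoxDict, frame_id: int, obj_id: float) -> Optional[List[float]]:
    """Return the box for `obj_id` on `frame_id`, or None if that object has no box there.

    2.0 and 2 hash equal in Python, so a float object id still matches an int key;
    a missing entry therefore means the object is genuinely absent on that frame.
    """
    return bbox_dict.get(frame_id, {}).get(obj_id)
```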

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

While running the 2D REFUGE example training and validation command, python train_2d.py -net sam2 -exp_name REFUGE_MedSAM2 -vis 1 -sam_ckpt ./checkpoints/sam2_hiera_small.pt -sam_config sam2_hiera_s -image_size 1024 -out_size 1024 -b 4 -val_freq 1 -dataset REFUGE -data_path ./data/REFUGE, the following error occurs.

Full Error Log:

/content/Medical-SAM2/sam2_train/modeling/sam/transformer.py:22: UserWarning: Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.
  OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
INFO:root:Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=True, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', sam_ckpt='./checkpoints/sam2_hiera_small.pt', sam_config='sam2_hiera_s', video_length=2, b=4, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='./data/REFUGE', path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24/Samples'})
Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=True, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', sam_ckpt='./checkpoints/sam2_hiera_small.pt', sam_config='sam2_hiera_s', video_length=2, b=4, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='./data/REFUGE', path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_07_12_17_24/Samples'})
Traceback (most recent call last):
  File "/content/Medical-SAM2/train_2d.py", line 124, in <module>
    main()
  File "/content/Medical-SAM2/train_2d.py", line 97, in main
    tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/Medical-SAM2/func_2d/function.py", line 335, in validation_sam
    vision_feats_temp = vision_feats[-1].permute(1, 0, 2).view(B, -1, 64, 64) 
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
ERROR conda.cli.main_run:execute(124): conda run bash -c python train_2d.py -net sam2 -exp_name REFUGE_MedSAM2 -vis 1 -sam_ckpt ./checkpoints/sam2_hiera_small.pt -sam_config sam2_hiera_s -image_size 1024 -out_size 1024 -b 4 -val_freq 1 -dataset REFUGE -data_path ./data/REFUGE failed. (See above for error)

System Info:

Python version: 3.10.12 
conda 23.11.0
Ubuntu 22.04.3

This error was fixed in PR #36.
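The sketch below only illustrates why the failing line breaks. It does not reproduce what PR #36 actually changed. permute() rewrites strides without moving data, so the result is non-contiguous, and view() cannot merge the H*W and C dimensions. reshape(), or an explicit .contiguous().view(), copies when needed. The shapes and the (H*W, B, C) layout of vision_feats[-1] are assumptions for illustration, with a 4x4 grid standing in for the real 64x64 one.

```python
import torch

# Small stand-in shapes; the traceback uses a 64x64 feature grid.
H, W, B, C = 4, 4, 2, 3

# Assumed layout for vision_feats[-1]: (H*W, B, C), freshly allocated
# and therefore contiguous.
feat = torch.arange(H * W * B * C, dtype=torch.float32).view(H * W, B, C)

# permute() only rewrites strides; no data moves, so the result is
# no longer contiguous in memory.
permuted = feat.permute(1, 0, 2)  # shape (B, H*W, C)

try:
    permuted.view(B, -1, H, W)  # same call pattern as function.py line 335
    view_error = None
except RuntimeError as err:
    view_error = err

# reshape() returns a view when the strides allow it and copies otherwise.
reshaped = permuted.reshape(B, -1, H, W)
# Equivalent explicit form: materialise a contiguous copy, then view it.
explicit = permuted.contiguous().view(B, -1, H, W)

print(permuted.is_contiguous())         # False
print(view_error is not None)           # True
print(tuple(reshaped.shape))            # (2, 3, 4, 4)
print(torch.equal(reshaped, explicit))  # True
```

Both fixes only reinterpret the (B, H*W, C) data under a new shape. They do not reorder channel and spatial positions.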

Issue with train_3d.py - KeyError: 2 in function.py

I'm running the MedSAM2 code for the "3D case - Abdominal Multiple Organs Segmentation" following the provided instructions. After downloading the BTCV dataset and setting up the environment, I encountered an error when running the command:

python train_3d.py -net sam2 -exp_name BTCV_MedSAM2 -sam_ckpt ./checkpoints/sam2_hiera_small.pt -sam_config sam2_hiera_s -image_size 1024 -val_freq 1 -prompt bbox -prompt_freq 2 -dataset btcv -data_path ./data/btcv

The error occurs in function.py at the line:

bbox = bbox_dict[id][ann_obj_id]

with the following traceback:

KeyError: 2

It seems like the prompt for this slice wasn't passed correctly, leading to the KeyError. Could you please provide guidance on how to resolve this issue?

Thank you!
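One plausible cause, which is an assumption and not confirmed in this thread, is that object 2 has no bounding box on the slice being prompted. In that case, indexing bbox_dict[id][ann_obj_id] directly raises KeyError. The sketch below assumes a hypothetical layout of {frame_id: {object_id: [x1, y1, x2, y2]}}, inferred only from that one line, and shows a defensive lookup that lets the caller skip objects with no box on a slice. The helper name prompts_for_frame is illustrative and is not part of function.py.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical layout inferred from `bbox_dict[id][ann_obj_id]`:
# {frame_id: {object_id: [x1, y1, x2, y2]}}
BBoxDict = Dict[int, Dict[int, List[float]]]


def prompts_for_frame(
    bbox_dict: BBoxDict, frame_id: int, obj_ids: Sequence[int]
) -> Tuple[Dict[int, List[float]], List[int]]:
    """Split obj_ids into those with a box on frame_id and those without.

    Indexing bbox_dict[frame_id][obj_id] directly raises KeyError for any
    object absent from that slice; .get() lets the caller skip it instead.
    """
    frame_boxes = bbox_dict.get(frame_id, {})
    found: Dict[int, List[float]] = {}
    missing: List[int] = []
    for obj_id in obj_ids:
        box = frame_boxes.get(obj_id)
        if box is None:
            missing.append(obj_id)
        else:
            found[obj_id] = box
    return found, missing


# Illustrative data: slice 0 contains organs 1 and 3 but not organ 2.
boxes = {0: {1: [10, 12, 40, 48], 3: [55, 60, 90, 95]}}
found, missing = prompts_for_frame(boxes, frame_id=0, obj_ids=[1, 2, 3])
print(sorted(found))  # [1, 3]
print(missing)        # [2]
```

Whether an absent object should be skipped or should fall back to another prompt type is a design decision for the training loop. This sketch only avoids the crash.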

Remove personal path

Line 215 of environment.yml contains the author's local conda environment path, 'prefix: /home/leo/anaconda3/envs/medsam2'. This could cause problems when other users create the environment on their own machines. I would recommend deleting it.

prefix: /home/leo/anaconda3/envs/medsam2

Ben
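conda env export writes a machine-specific prefix: line at the end of the file. A minimal sketch of removing it programmatically is shown below. It assumes the line starts at column 0, as exported files have it, and it leaves any indented prefix keys alone. The example YAML content is illustrative and is not the project's actual environment.yml.

```python
from pathlib import Path


def strip_conda_prefix(yaml_text: str) -> str:
    """Drop the top-level `prefix:` line that `conda env export` records.

    Only lines starting at column 0 are removed, so indented keys that
    happen to be named `prefix` are left untouched.
    """
    kept = [
        line
        for line in yaml_text.splitlines(keepends=True)
        if not line.startswith("prefix:")
    ]
    return "".join(kept)


def clean_environment_file(path: str = "environment.yml") -> None:
    """Rewrite the file in place without its machine-specific prefix."""
    env_file = Path(path)
    env_file.write_text(strip_conda_prefix(env_file.read_text()))


# Illustrative input only.
example = (
    "name: medsam2\n"
    "dependencies:\n"
    "  - python\n"
    "prefix: /home/leo/anaconda3/envs/medsam2\n"
)
print(strip_conda_prefix(example))
```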

Release of Medical-SAM2 pretrained weights

Hi, and thanks for the amazing work on the paper! I was wondering whether there are any plans to release the Medical-SAM2 weights fine-tuned on medical imaging data, as shown in the paper?

Details about Experiment on STARE Vessel Dataset

Thank you for your remarkable work. MedSAM2 used a mask prompt on the STARE dataset, as shown in Figure 4 of the paper. I would like to know whether the mask covers the entire vessel region in one large area, or if it is detailed enough to cover each vessel individually with an approximate mask.
If you could guide me on how to perform testing and visualization on STARE, I would be extremely grateful.
[Attached image: im0001_2ndHO]

How to compile and run train_3d.py under Python 3.10

When I run train_3d.py under Python 3.10, I get the following error:
ImportError: Python version mismatch: module was compiled for Python 3.12, but the interpreter version is incompatible: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0].

Changing the environment is a hassle. Is there a quick workaround? If not, I will rebuild the environment with Python 3.12.
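Before rebuilding anything, it can help to check which interpreter a compiled extension targets. On Linux, CPython extension file names usually carry a tag such as cpython-312. The helper below, which is a hypothetical name and not part of the project, compares that tag with the running interpreter. It returns None for untagged names such as _C.so, because the target version cannot be read from those names.

```python
import re
import sys
import sysconfig
from typing import Optional

# Matches CPython ABI tags such as "cpython-312" in extension file names.
_TAG = re.compile(r"cpython-(\d)(\d+)")


def built_for_this_python(filename: str) -> Optional[bool]:
    """Compare an extension's CPython tag with the running interpreter.

    Returns True/False for tagged names like
    "_C.cpython-312-x86_64-linux-gnu.so", and None for untagged names like
    "_C.so", whose target version cannot be read from the name alone.
    """
    match = _TAG.search(filename)
    if match is None:
        return None
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) == tuple(sys.version_info[:2])


# Suffix the current interpreter expects for extensions; on CPython 3.10
# under x86-64 Linux this is ".cpython-310-x86_64-linux-gnu.so".
print(sysconfig.get_config_var("EXT_SUFFIX"))
```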

Question about the 3D segmentation of organs

Hey :)
The paper looks really impressive, and I am considering applying it in my workflow. Before doing so, I have a quick question:

The Medical SAM 2 paper states in the abstract:

That allows users to provide a prompt for just one or a specific image targeting an object, after which the model can autonomously segment the same type of object in all subsequent images, regardless of temporal relationships between the images.

And later

In this scenario, users only need to provide a prompt for the first or a specific frame targeting an area, such as the optic cup in a fundus image. Subsequently, the model can autonomously segment the optic cup in subsequent images, even in the absence of temporal relationships between these images.

Does that mean that if I annotate an organ in a central slice, only half of the organ is segmented?

This in turn would mean that:

  1. in order to segment the organ, I have to find the very first pixel of that organ in a specific direction,
    • which is difficult to annotate, and
    • since this annotation lies right at the border of the organ, segmenting from it is hard, because a border pixel is not a typical pixel of the organ.
  2. Organs like the colon, which go back and forth through the volume, cannot be segmented, because the information flows in only one direction.

Am I missing something here?
Thank you!
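Point 1 can be made concrete with a small sketch. If propagation runs in one direction only, a prompt on slice k reaches slices k through the end of the volume, and a second pass in the opposite direction is needed to cover the rest. Upstream SAM 2's video predictor exposes a reverse argument on propagate_in_video for this purpose. Whether Medical-SAM2's training and validation scripts use it is not stated in this thread. The helper below is hypothetical and only computes the slice visiting orders.

```python
from typing import List, Tuple


def propagation_orders(prompt_idx: int, num_slices: int) -> Tuple[List[int], List[int]]:
    """Slice visiting order for a forward and a backward pass from one prompt.

    A forward pass from a mid-volume prompt only reaches slices
    prompt_idx..num_slices-1; a reverse pass covers prompt_idx..0.
    """
    if not 0 <= prompt_idx < num_slices:
        raise ValueError("prompt_idx must index a slice in the volume")
    forward = list(range(prompt_idx, num_slices))
    backward = list(range(prompt_idx, -1, -1))
    return forward, backward


fwd, bwd = propagation_orders(prompt_idx=3, num_slices=7)
print(fwd)  # [3, 4, 5, 6]
print(bwd)  # [3, 2, 1, 0]
```

Together the two passes cover every slice once, apart from the shared prompted slice. This addresses the half-organ concern only and says nothing about point 2.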
