This repository contains the code for training and inference of Iris-SAM.
- Download SAM's checkpoints
- Iris-SAM-checkpoints
First, create and activate a conda environment with the following commands:
conda create -n iris-sam python=3.8
conda activate iris-sam
Then, install the requirements:
pip install -r requirements.txt
The dataset is provided in the `data` directory, along with images and masks. The images are in `.jpg` format.
To train the model, run the following command:
python sam_train_losses.py --data_dir data --epochs 100 --loss focal --lr 0.0001 --save_dir checkpoints/ND
Model checkpoints are saved in `save_dir` (default: `checkpoints/ft`). The program saves two checkpoints: one with the optimizer state and one without. The checkpoint with the optimizer state is used for resuming training.
The checkpoint without the optimizer state (`model.pt`) is used for inference. We already provide a fine-tuned checkpoint on Google Drive and suggest using `model.pt`.
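For readers unfamiliar with the `--loss focal` option used in the training command, the following is a minimal sketch of a standard binary focal loss computed per pixel. It is not taken from `sam_train_losses.py`: the function names, the `gamma` and `alpha` defaults, and the pure-Python formulation are assumptions made for illustration, and the script's actual implementation may differ.

```python
import math


def binary_focal_loss(p, target, gamma=2.0, alpha=0.25, eps=1e-7):
    """Focal loss for one pixel (illustrative, hypothetical helper).

    p      : predicted foreground probability in [0, 1]
    target : ground truth label, 0 (background) or 1 (iris)
    gamma  : focusing parameter; larger values down-weight easy pixels more
    alpha  : class-balancing weight applied to the foreground class
    """
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    p_t = p if target == 1 else 1.0 - p  # probability assigned to the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def mean_focal_loss(probs, targets, **kwargs):
    """Average the per-pixel focal loss over flattened predictions and labels."""
    losses = [binary_focal_loss(p, t, **kwargs) for p, t in zip(probs, targets)]
    return sum(losses) / len(losses)
```

With `gamma=0` and `alpha=0.5` this reduces to half the usual binary cross-entropy, which gives a quick sanity check when comparing loss choices.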
We also provide a script to test the model on the validation set. To test the model, run the following command:
python sam_test.py --data_dir data --pretrained_model *path_to_model/model.pt --save_dir *path_to_output/ND_test
To run inference, we need to provide the bounding box (bbox) coordinates of the segmentation region in the image (automated). This is how SAM works, i.e., it requires a prompt along with the image for prediction. In our case, the prompt is the bbox coordinates. We provide a script to run inference on a single image. To use it, run the following command:
python sam_infer.py --image_path {PATH_TO_IMAGE_OR_DIR} --pretrained_model weights/model.pt --save_dir results
When you run the above command, an image window will open. You will be asked to click the top-left and bottom-right corners of the bbox with the mouse pointer. Results will be saved in `save_dir` (default: `outputs/results`). The results include the overlay of predicted and ground truth masks.
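To illustrate how two mouse clicks can become a box prompt, here is a minimal sketch that turns the clicked points into a `[x_min, y_min, x_max, y_max]` list. The helper name `clicks_to_bbox` and its clipping behaviour are assumptions for illustration and are not taken from `sam_infer.py`.

```python
def clicks_to_bbox(first_click, second_click, image_width, image_height):
    """Convert two (x, y) clicks into an XYXY box clipped to the image.

    Hypothetical helper: the clicks may arrive in any order, so the
    coordinates are sorted before the box is built.
    """
    (x1, y1), (x2, y2) = first_click, second_click
    x_min, x_max = sorted((x1, x2))
    y_min, y_max = sorted((y1, y2))

    # Clip to valid pixel indices so the box never leaves the image.
    x_min = max(0, min(x_min, image_width - 1))
    x_max = max(0, min(x_max, image_width - 1))
    y_min = max(0, min(y_min, image_height - 1))
    y_max = max(0, min(y_max, image_height - 1))

    if x_max <= x_min or y_max <= y_min:
        raise ValueError("The two clicks do not span a box with positive area")
    return [x_min, y_min, x_max, y_max]
```

The XYXY pixel order matches the box prompt format used by the `segment_anything` package, so a list like this could be converted to a NumPy array and passed as the box prompt.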
The inference command accepts either a single image or a directory containing multiple images. For a directory, also provide the extension of the images. For example, if the images are in `.jpg` format, run the following command:
python sam_infer.py --image_path {PATH_TO_IMAGE_DIR} -extension jpg --pretrained_model weights/model.pt --save_dir results
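The single-image versus directory behaviour described above can be sketched as follows. This is an assumed illustration rather than the logic in `sam_infer.py`; the helper name `collect_images` and the case-insensitive extension match are hypothetical choices.

```python
from pathlib import Path


def collect_images(image_path, extension="jpg"):
    """Return a sorted list of image paths for a file or a directory.

    Hypothetical helper: a file path is returned as a one-element list,
    while a directory is filtered by the given extension.
    """
    path = Path(image_path)
    if path.is_file():
        return [path]
    if path.is_dir():
        suffix = "." + extension.lstrip(".").lower()
        return sorted(
            p for p in path.iterdir()
            if p.is_file() and p.suffix.lower() == suffix
        )
    raise FileNotFoundError(f"No such file or directory: {image_path}")
```

Returning a list in both cases lets the rest of an inference loop treat one image and many images the same way.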
If you get errors from `cv2` while plotting bboxes on the image, uninstall the headless build and reinstall OpenCV as follows:
pip uninstall opencv-python-headless -y
pip install opencv-python --upgrade
If you find this repository useful, please consider giving it a star ⭐ and citing our paper (https://arxiv.org/abs/2402.06497).
@article{farmanifard2024iris,
title={Iris-SAM: Iris Segmentation Using a Foundation Model},
author={Farmanifard, Parisa and Ross, Arun},
journal={arXiv preprint arXiv:2402.06497},
year={2024}
}
- We thank the University of Salzburg and Halmstad University for providing the ground truth datasets.
- We thank Meta AI for making the source code of SAM publicly available.
- MedSAM.