haochenheheda / segment-anything-annotator

License: GNU General Public License v3.0

Python 100.00%

segment-anything-annotator's Introduction

segment-anything-annotator

We developed a Python UI based on labelme and segment-anything for pixel-level annotation. It supports generating multiple masks with SAM (box/point prompts), efficient polygon modification, and category recording. We will add more features, such as CLIP-based methods for category proposal and VOS methods for mask association in video datasets.

Any feedback or suggestions are welcome. We will continuously add features and fix bugs. 👀👀👀

News

28 Apr: Changed the output format to the labelme format. If you want to use the old output format, please use backup/annotator.py.

21 Apr: Added an annotation script for video datasets. See Video Usage for details. Video_Demo

Features

  • Interactive Segmentation by SAM (both box and point prompts)
  • Multiple Output Choices
  • Category Annotation
  • Polygon modification
  • CLIP for Category Proposal
  • STCN for Video Dataset Annotation

Demo


Installation

  1. Python>=3.8
  2. Pytorch
  3. pip install -r requirements.txt

Usage

1. Start the Annotation Platform

python annotator.py --app_resolution 1000,1600 --model_type vit_b --keep_input_size True --max_size 720

--model_type: vit_b, vit_l, vit_h

--keep_input_size: True: keep the original image size for SAM; False: resize the input image to --max_size (saves GPU memory)
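
As a rough illustration of what --keep_input_size False combined with --max_size implies, the sketch below computes a target size that caps the longer image side at max_size while preserving the aspect ratio. The function name target_size and the choice to cap the longer side are assumptions made for illustration; the repository's actual resizing logic may differ.

```python
def target_size(height, width, max_size, keep_input_size=True):
    """Return the (height, width) that would be passed to SAM.

    Assumption (not confirmed by the repository): when resizing is
    enabled, the longer side is capped at max_size and the aspect
    ratio is preserved. Images already small enough are left alone.
    """
    longest = max(height, width)
    if keep_input_size or longest <= max_size:
        return height, width
    scale = max_size / longest
    return round(height * scale), round(width * scale)


# A 1080x1920 frame with resizing enabled and max_size=720.
print(target_size(1080, 1920, 720, keep_input_size=False))
```

Smaller inputs reduce GPU memory use at the cost of some mask detail, which is the trade-off the flag description refers to.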

2. Load the category list file if you want to annotate object categories.

Click 'Category File' on the top toolbar and choose your own file, such as the categories.txt in this repo.

3. Specify the image and save folders

Click 'Image Directory' on the top toolbar to specify the folder containing the images (in .jpg or .png). Click 'Save Directory' on the top toolbar to specify the folder for saving the annotations. The annotations of each image will be saved as a JSON file in the following format:

[
  #object1
  {
      'label':<category>, 
      'group_id':<id>,
      'shape_type':'polygon',
      'points':[[x1,y1],[x2,y2],[x3,y3],...]
  },
  #object2
  ...
]
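
To make the format above concrete, here is a minimal sketch that reads such a list of objects and reports a bounding box per polygon. It assumes the file on disk is valid JSON (double-quoted keys) with the four fields shown; the helper names polygon_bbox and summarize, and the sample data, are made up for illustration.

```python
import json


def polygon_bbox(points):
    """Axis-aligned bounding box (x_min, y_min, x_max, y_max) of a polygon."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return min(xs), min(ys), max(xs), max(ys)


def summarize(annotation_text):
    """Return (label, group_id, bbox) for every polygon in the annotation."""
    objects = json.loads(annotation_text)
    return [
        (obj["label"], obj["group_id"], polygon_bbox(obj["points"]))
        for obj in objects
        if obj.get("shape_type") == "polygon"
    ]


# Hypothetical annotation with a single triangular object.
sample = json.dumps([
    {
        "label": "dog",
        "group_id": 1,
        "shape_type": "polygon",
        "points": [[10, 20], [50, 20], [30, 60]],
    }
])
print(summarize(sample))
```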

4. Load SAM model

Click "Load SAM" on the top toolbar to load the SAM model. The model is downloaded automatically the first time, so please be patient. Alternatively, you can download the models manually and put them in the root directory, named vit_b.pth, vit_l.pth and vit_h.pth.

5. Annotating Functions

Manual Polygons: manually add masks by clicking on the boundary of the objects, just like in Labelme (press the right button and drag to draw arcs easily).

Point Prompt: generate mask proposals with clicks. Left and right mouse clicks represent positive and negative clicks respectively. Several mask proposals appear in the boxes below (Proposal1-4); choose one by clicking it or with the shortcuts 1, 2, 3, 4.

Box Prompt: generate mask proposals with boxes.

Accept (shortcut: 'a'): accept the chosen proposal and add it to the annotation dock.

Reject (shortcut: 'r'): reject the proposals and clean the workspace.

Save (shortcut: 's'): save annotations to file. Do not forget to save your annotations for each image, or they will be lost when you switch to the next image.

Edit Polygons: in this mode, you can modify the annotated objects, for example by double-clicking an object item in the annotation dock to change its category label or id. You can also modify the boundary by dragging the points on it.

Delete (shortcut: 'd'): in Edit Mode, delete the selected/highlighted objects from the annotation dock.

Reduce Point: in Edit Mode, if a polygon is too dense to edit, use this button to reduce the number of points on the selected polygons. This will slightly reduce the annotation quality.

Zoom in/out: hold 'CTRL' and scroll the mouse wheel.

Class On/Off: if Class is turned on, a dialog appears after you accept a mask so you can record the category and id; otherwise the category defaults to "Object".
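
The Reduce Point function above thins out dense polygons. One simple way to do this, shown here only as a sketch, is a minimum-distance filter that drops any vertex lying too close to the last vertex kept. The function reduce_points and its min_dist parameter are illustrative assumptions, not the tool's actual implementation.

```python
import math


def reduce_points(points, min_dist):
    """Drop vertices closer than min_dist to the previously kept vertex.

    Assumption: a greedy distance filter; the annotator may use a
    different simplification strategy.
    """
    if not points:
        return []
    kept = [points[0]]
    for p in points[1:]:
        if math.dist(p, kept[-1]) >= min_dist:
            kept.append(p)
    return kept


dense = [(0, 0), (1, 0), (2, 0), (5, 0), (5, 1), (10, 0)]
print(reduce_points(dense, min_dist=4))
```

Because vertices are discarded, the simplified outline deviates slightly from the original mask, which is why the quality drops a little.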

Video Usage

1. Clone STCN, download stcn.pth, and put both in the root directory like this:

-| segment-anything-annotator
    -| annotator_video.py
    .....
    -| STCN
    -| stcn.pth

2. Start the Annotation Platform

python annotator_video.py --app_resolution 1000,1600 --model_type vit_b --keep_input_size True --max_size 720 --max_size_STCN 600

--model_type: vit_b, vit_l, vit_h

--keep_input_size: True: keep the original image size for SAM; False: resize the input image to --max_size (saves GPU memory)

--max_size_STCN: the maximum input image size for STCN (should not be too large)

3. Specify the video and save folders

Click 'Video Directory' and 'Save Directory'. The folder containing the videos should be structured like this:

-| video_fold
    -| video1
        -| 00000.jpg
        -| 00001.jpg
        ...
    -| video2
        -| 00000.jpg
        -| 00001.jpg     
    ...
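
As a quick sanity check of the layout above, the following sketch walks a video folder and lists the frames found for each video in sorted order. The function list_videos is a hypothetical helper, and it assumes frames are .jpg files directly inside each video subfolder, as in the tree shown.

```python
from pathlib import Path
import tempfile


def list_videos(video_root):
    """Map each video subfolder name to its sorted list of .jpg frame names."""
    root = Path(video_root)
    videos = {}
    for video_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        frames = sorted(
            f.name for f in video_dir.iterdir() if f.suffix.lower() == ".jpg"
        )
        videos[video_dir.name] = frames
    return videos


# Build a throwaway folder that mimics the expected structure.
with tempfile.TemporaryDirectory() as tmp:
    for video in ("video1", "video2"):
        video_dir = Path(tmp) / video
        video_dir.mkdir()
        for i in range(2):
            (video_dir / f"{i:05d}.jpg").touch()
    print(list_videos(tmp))
```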

4. Load STCN and SAM

Click 'Load STCN' and 'Load SAM'.

5. Video Annotating Functions

(a) Finish the annotations of the first frame with SAM

(b) Press and hold left Control, then click the left mouse button to select the objects you want to track (they should be highlighted in color)

(c) Click 'add object to memory' to initialize the tracklets

(d) Move to the next frame and click Propagate to obtain the tracked masks on the new frame (the results are saved automatically when you change frames)

(e) If the propagated masks are not good enough, press 'e' to enter Edit mode and correct the masks manually. You can also use 'Add as key frame' to add a frame as a reference, which improves propagation stability.

Shortcuts: b: last frame, n: next frame, e: edit mode, a: accept proposal, r: reject proposal, d: delete, s: save, space: propagate

To Do

  • CLIP for Category Proposal
  • STCN for Video Dataset Annotation
  • Fix bugs and optimize the UI
  • Annotation Files -> Labelme Format/COCO Format/Youtube-VIS Format

Acknowledgement

This repo is built on SAM and Labelme.


segment-anything-annotator's Issues

[Feature Request] Adjust image size in the polygon-editing window

When labeling a large image (in my case, a 2000x2000 px image), the polygon-editing window seems to display only the top-left part (about 700x1000 px) of the original image. This limitation really slows down labeling, because I have to crop my image into many 700x700 images and then stitch them together.

Is there a way to make the window larger?

The point prompt is giving arbitrary mask

Hi,
thanks for your great work! It's really cool to speed up the annotation process with SAM. However, I have run into a problem with the tool: the prompt function seems to stop working after labeling a few images or switching to another batch of images.
In the screenshot I provided, I am trying to label the flying birds in the scene, but the tool gives a nonsensical proposal and I have to quit and start over. I am not sure why this happens. Can you help me with this?

How to resize the category choose widget

Amazing work!

How can I increase the height of the category selection dialog so that it displays all candidate categories without scrolling the mouse wheel?

Application fails to load within an anaconda virtual env in Ubuntu 20.04 [Opencv and qt error]

Hi,

I believe that, due to an OpenCV error, your application fails to load when using a virtual environment and the requirements.txt file that you provided. After some searching I found a solution, which I share in this issue to help others who encounter this error. It might be good to update requirements.txt to include the right OpenCV package.

To reproduce the error: after creating the environment and installing all the requirements, I launched the application with

python annotator.py --app_resolution 1000,1600 --model_type vit_h --keep_input_size True --max_size 720

which produced the following error

QObject::moveToThread: Current thread (0x55ce9ec4a6b0) is not the object's thread (0x55cea1a8a450).
Cannot move to target thread (0x55ce9ec4a6b0)

qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "/home/juan1995/anaconda3/envs/sam/lib/python3.9/site-packages/cv2/qt/plugins" even though it was found.
This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.

Available platform plugins are: xcb, eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl.

Setup to reproduce error

  • machine: Ubuntu 20.04

First, I created a new anaconda virtual environment with

conda create -n sam python=3.9 -y
pip install -r requirements.txt

Then, within the conda environment, execute

python annotator.py --app_resolution 1000,1600 --model_type vit_h --keep_input_size True --max_size 720

Solution

As indicated in this post, the error can be solved by

pip uninstall opencv-python
pip install opencv-python-headless

I got the annotation tool working after doing this.
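
Before swapping packages, it can help to see which OpenCV distributions are currently installed in the environment. The sketch below uses the standard-library importlib.metadata for that; the helper name installed_opencv_dists and the list of candidate package names are assumptions chosen for illustration.

```python
from importlib import metadata

# Common OpenCV wheels on PyPI; extend the tuple if you use another build.
OPENCV_DISTS = (
    "opencv-python",
    "opencv-contrib-python",
    "opencv-python-headless",
    "opencv-contrib-python-headless",
)


def installed_opencv_dists():
    """Return {distribution name: version} for the OpenCV wheels that are installed."""
    found = {}
    for name in OPENCV_DISTS:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # this variant is not installed
    return found


print(installed_opencv_dists())
```

If the non-headless opencv-python shows up here, the fix described above (replacing it with opencv-python-headless) is the one that applies.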

Video feature crashes as soon as files are imported

Traceback (most recent call last):
  File "annotator_video.py", line 309, in <lambda>
    lambda: self.clickSaveChoose(),
  File "annotator_video.py", line 746, in clickSaveChoose
    self.loadImg()
  File "annotator_video.py", line 697, in loadImg
    self.raw_h, self.raw_w = cv2.imread(self.current_img).shape[:2]
AttributeError: 'NoneType' object has no attribute 'shape'
Aborted (core dumped)
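
The AttributeError arises because cv2.imread returns None instead of raising when a file is missing, unreadable, or not an image. A small guard like the sketch below would turn that into a clearer error; the helper name image_hw is hypothetical and this is not a patch taken from the repository.

```python
def image_hw(img, path):
    """Return (height, width) of a loaded image, or raise a clear error.

    img is whatever cv2.imread(path) returned; it is None when the file
    is missing, unreadable, or not a supported image.
    """
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path!r}")
    return img.shape[:2]


# Intended use inside the loader (requires OpenCV):
#   raw_h, raw_w = image_hw(cv2.imread(current_img), current_img)
```

With this check, the error message names the offending path, which makes it easier to spot a wrong folder structure or a non-image file in the video directory.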

License

Hello Haochenheheda,

thank you very much for making your Labeling-Software available. I would like to use it.
Under which license does this fall? Is it free to use for everyone?

Kind regards, Mario Schweda

Screen Resolution not matching

Hello! I am using the segment-anything-annotator code on a MacBook Air (M2 chip). For some reason, the default resolution in annotator.py results in a skewed window, so I cannot see certain parts of the GUI. I tried adjusting the resolution, but it is always either too small or too large. I have attached a screenshot (with a loaded image) of the GUI at the default resolution for reference. Could you please tell me what the issue could be?

Reduce point bug

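Point reduction works on some annotated masks.
On others, clicking the button has no effect and the points cannot be reduced.
The behavior is very unstable.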

QT-related error when using in docker

I am running the application in nvcr.io/nvidia/pytorch:23.04-py3, a Docker image provided by NVIDIA, and I have installed the dependencies from requirements.txt. However, I still get this error:

qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "" even though it was found.
This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.

Available platform plugins are: eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl, xcb.

Aborted (core dumped)

Following this, I ran export QT_DEBUG_PLUGINS=1 to enable Qt's plugin debugging mode and got the following message:

Got keys from plugin meta data ("xcb")
QFactoryLoader::QFactoryLoader() checking directory path "/usr/bin/platforms" ...
Cannot load library /usr/local/lib/python3.8/dist-packages/PyQt5/Qt5/plugins/platforms/libqxcb.so: (libdbus-1.so.3: cannot open shared object file: No such file or directory)
QLibraryPrivate::loadPlugin failed on "/usr/local/lib/python3.8/dist-packages/PyQt5/Qt5/plugins/platforms/libqxcb.so" : "Cannot load library /usr/local/lib/python3.8/dist-packages/PyQt5/Qt5/plugins/platforms/libqxcb.so: (libdbus-1.so.3: cannot open shared object file: No such file or directory)"
qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "" even though it was found.
This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.

Available platform plugins are: eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl, xcb.

Aborted (core dumped)

This reveals which dependency is missing. Installing the corresponding package solves the problem. (More than one library may be missing; repeat the process until the application starts.)
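
Since several libraries can be missing, it may be convenient to extract all of them from the QT_DEBUG_PLUGINS output in one pass. The sketch below does this with a regular expression matched against the "cannot open shared object file" message seen above; the helper name missing_libraries is an assumption, and mapping each library to a system package is left to the reader.

```python
import re

# Matches "(libfoo.so.N: cannot open shared object file" in the Qt debug log.
MISSING_LIB = re.compile(r"\((\S+): cannot open shared object file")


def missing_libraries(debug_output):
    """Return the sorted, de-duplicated shared libraries reported as missing."""
    return sorted(set(MISSING_LIB.findall(debug_output)))


log = (
    "Cannot load library /usr/local/lib/python3.8/dist-packages/PyQt5/Qt5/"
    "plugins/platforms/libqxcb.so: (libdbus-1.so.3: cannot open shared "
    "object file: No such file or directory)"
)
print(missing_libraries(log))
```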

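labelme format

I made some changes to the source code to produce output in labelme format.
I added a function, lesscnt, that automatically reduces the number of polygon points after each SAM prediction, and rewrote the loadAnno and format_shape functions to output labelme format.
After these changes, changing the output directory no longer seems to serve much purpose, so I think that button could be removed.
All other changes are marked with #change in the code.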
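
In the modified saveLabels in the code that follows, format_shape2 writes each label as the category name followed by an underscore and the group id. When reading such files back, the two parts need to be separated again. The sketch below shows one way to do that; split_label is a hypothetical helper and assumes the group id is either an integer or the literal text None.

```python
def split_label(combined):
    """Split 'category_groupid' into (category, group_id).

    Assumption: the suffix after the last underscore is an integer id
    or the string 'None'; anything else is treated as part of the label.
    """
    label, sep, gid = combined.rpartition("_")
    if not sep:
        return combined, None
    if gid.isdigit():
        return label, int(gid)
    if gid == "None":
        return label, None
    return combined, None


print(split_label("dog_3"))
print(split_label("traffic_light_None"))
```

Splitting on the last underscore keeps category names that themselves contain underscores intact.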

import sys
import functools
import cv2
import glob
import os
import os.path as osp
import imgviz
import html
import json
import math
import argparse
import numpy as np
import tempfile
import torch

from PyQt5.QtWidgets import QWidget, QApplication, QMainWindow, QPushButton, QLabel, QFileDialog, QProgressBar, QComboBox, QScrollArea, QDockWidget, QMessageBox
from PyQt5.QtGui import QPixmap, QIcon, QImage
from PyQt5.Qt import QSize
from qtpy.QtCore import Qt
from qtpy import QtCore
from qtpy import QtGui, QtWidgets
from canvas import Canvas
import utils
from utils.download_model import download_model

from labelme.widgets import ToolBar, UniqueLabelQListWidget, LabelDialog, LabelListWidget, LabelListWidgetItem, ZoomWidget
from labelme import PY2
from labelme.label_file import LabelFile
from labelme.label_file import LabelFileError


from shape import Shape

from PIL import Image

from collections import namedtuple
Click = namedtuple('Click', ['is_positive', 'coords'])

from segment_anything import sam_model_registry, SamPredictor

#add
from PyQt5.QtWidgets import QVBoxLayout


LABEL_COLORMAP = imgviz.label_colormap()

class MainWindow(QMainWindow):

    FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = 0, 1, 2

    def __init__(self, parent=None, global_w=1000, global_h=1800, model_type='vit_b', keep_input_size=True, max_size=1080):
        super(MainWindow, self).__init__(parent)
        self.resize(global_w, global_h)
        self.model_type = model_type
        self.keep_input_size = keep_input_size
        self.max_size = float(max_size)

        self.setWindowTitle('segment-anything-annotator')
        self.canvas = Canvas(self,
            epsilon=10.0,
            double_click='close',
            num_backups=10,
            app=self,
        )
        
        self._noSelectionSlot = False
        self.current_output_dir = 'output'
        os.makedirs(self.current_output_dir, exist_ok=True)
        self.current_output_filename = ''
        self.canvas.zoomRequest.connect(self.zoomRequest)

        self.memory_shapes = []
        self.sam_mask = []
        self.sam_mask_proposal = []
        self.image_encoded_flag = False
        self.min_point_dis = 4

        self.predictor = None

        self.scroll_values = {
            Qt.Horizontal: {},
            Qt.Vertical: {},
        }
        self.scrollArea = QScrollArea(self)
        
        self.scrollArea.setWidget(self.canvas)
        self.scrollArea.setWidgetResizable(True)
        self.scrollBars = {
            Qt.Vertical: self.scrollArea.verticalScrollBar(),
            Qt.Horizontal: self.scrollArea.horizontalScrollBar(),
        }
        self.canvas.scrollRequest.connect(self.scrollRequest)
        self.canvas.newShape.connect(self.newShape)
        self.canvas.shapeMoved.connect(self.setDirty)
        self.canvas.selectionChanged.connect(self.shapeSelectionChanged)
        self.canvas.drawingPolygon.connect(self.toggleDrawingSensitive)

        self.uniqLabelList = UniqueLabelQListWidget()
        self.uniqLabelList.setToolTip(
            self.tr(
                "Select label to start annotating for it. "
                "Press 'Esc' to deselect."
            )
        )
        self.labelDialog = LabelDialog(
            parent=self,
            labels=[],
            sort_labels=False,
            show_text_field=True,
            completion='contains',
            fit_to_content={'column': True, 'row': False},
        )

        self.labelList = LabelListWidget()
        self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
        self.labelList.itemDoubleClicked.connect(self.editLabel)
        self.labelList.itemChanged.connect(self.labelItemChanged)
        self.labelList.itemDropped.connect(self.labelOrderChanged)

        self.shape_dock = QDockWidget(
            self.tr("Polygon Labels"), self
        )
        self.shape_dock.setObjectName("Labels")
        self.shape_dock.setWidget(self.labelList)

        # Read the category names (one per line) and close the file afterwards.
        with open('C2.txt', 'r', encoding='utf-8') as f:
            self.category_list = [line.strip() for line in f]
        self.labelDialog = LabelDialog(
            parent=self,
            labels=self.category_list,
            sort_labels=False,
            show_text_field=True,
            completion='contains',
            fit_to_content={'column': True, 'row': False},
        )
        self.zoom_values = {}
        self.video_directory = ''
        self.video_list = []
        self.video_len = len(self.video_list)

        self.img_list = []
        self.img_len = len(self.img_list)
        self.current_img_index = 0
        self.current_img = ''

        self.button_next = QPushButton('Next Image', self)
        self.button_next.clicked.connect(self.clickButtonNext)
        self.button_last = QPushButton('Last Image', self)
        self.button_last.clicked.connect(self.clickButtonLast)

        self.img_progress_bar = QProgressBar(self)
        self.img_progress_bar.setMinimum(0)
        self.img_progress_bar.setMaximum(1)
        self.img_progress_bar.setValue(0)
        #add
        #layout1 = QVBoxLayout()
        self.button_proposal1 = QPushButton('Proposal1', self)
        self.button_proposal1.clicked.connect(self.choose_proposal1)
        self.button_proposal1.setShortcut('1')
        self.button_proposal2 = QPushButton('Proposal2', self)
        self.button_proposal2.clicked.connect(self.choose_proposal2)
        self.button_proposal2.setShortcut('2')
        self.button_proposal3 = QPushButton('Proposal3', self)
        self.button_proposal3.clicked.connect(self.choose_proposal3)
        self.button_proposal3.setShortcut('3')
        self.button_proposal4 = QPushButton('Proposal4', self)
        self.button_proposal4.clicked.connect(self.choose_proposal4)
        self.button_proposal4.setShortcut('4')
        self.button_proposal_list = [self.button_proposal1, self.button_proposal2, self.button_proposal3, self.button_proposal4]
        #add
        #layout1.addWidget(self.button_proposal1 )
        #layout1.addWidget(self.button_proposal2 )
        #layout1.addWidget(self.button_proposal3 )
        #layout1.addWidget(self.button_proposal4 )
        #self.setLayout(layout1)
        
        
        self.class_on_flag = True
        self.class_on_text = QLabel("Class On", self)
        

        #naive layout
        self.scrollArea.move(int(0.02 * global_w), int(0.08 * global_h))
        self.scrollArea.resize(int(0.75 * global_w), int(0.7 * global_h))
        self.shape_dock.move(int(0.79 * global_w), int(0.08 * global_h))
        self.shape_dock.resize(int(0.2 * global_w), int(0.7 * global_h))
        self.button_next.move(int(0.18 * global_w), int(0.85 * global_h))
        self.button_next.resize(int(0.1 * global_w),int(0.04 * global_h))
        self.button_last.move(int(0.01 * global_w), int(0.85 * global_h))
        self.button_last.resize(int(0.1 * global_w),int(0.04 * global_h))
        self.class_on_text.move(int(0.01 * global_w), int(0.9 * global_h))
        self.img_progress_bar.move(int(0.01 * global_w), int(0.8 * global_h))
        self.img_progress_bar.resize(int(0.3 * global_w),int(0.04 * global_h))
        
        self.button_proposal1.resize(int(0.17 * global_w),int(0.14 * global_h))
        self.button_proposal1.move(int(0.33 * global_w), int(0.8 * global_h))
        self.button_proposal2.resize(int(0.17 * global_w),int(0.14 * global_h))
        self.button_proposal2.move(int(0.50 * global_w), int(0.8 * global_h))
        self.button_proposal3.resize(int(0.17 * global_w),int(0.14 * global_h))
        self.button_proposal3.move(int(0.67 * global_w), int(0.8 * global_h))
        self.button_proposal4.resize(int(0.17 * global_w),int(0.14 * global_h))
        self.button_proposal4.move(int(0.84 * global_w), int(0.8 * global_h))
        
        
        
        self.zoomWidget = ZoomWidget()

        action = functools.partial(utils.newAction, self)
        

        categoryFile = action(
            self.tr("Category File"),
            lambda: self.clickCategoryChoose(),
            'None',
            "objects",
            self.tr("Category File"),
            enabled=True,
        )
        imageDirectory = action(
            self.tr("Image Directory"),
            lambda: self.clickFileChoose(),
            'None',
            "objects",
            self.tr("Image Directory"),
            enabled=True,
        )
        LoadSAM = action(
            self.tr("Load SAM"),
            lambda: self.clickLoadSAM(),
            'None',
            "objects",
            self.tr("Load SAM"),
            enabled=True,
        )
        AutoSeg = action(
            self.tr("AutoSeg"),
            lambda: self.clickAutoSeg(),
            'None',
            "objects",
            self.tr("AutoSeg"),
            enabled=False,
        )
        promptSeg = action(
            self.tr("Accept"),
            lambda: self.addSamMask(),
            'a',
            "objects",
            self.tr("Accept"),
            enabled=False,
        )

        saveDirectory = action(
            self.tr("Save Directory"),
            lambda: self.clickSaveChoose(),
            'None',
            "objects",
            self.tr("Save Directory"),
            enabled=True,
        )

        createMode = action(
            self.tr("Manual Polygons"),
            lambda: self.toggleDrawMode(False, createMode="polygon"),
            'Ctrl+W',
            "objects",
            self.tr("Start drawing polygons"),
            enabled=True,
        )
        createPointMode = action(
            self.tr("Point Prompt"),
            lambda: self.toggleDrawMode(False, createMode="point"),
            'None',
            "objects",
            self.tr("Point Prompt"),
            enabled=True,
        )
        createRectangleMode = action(
            self.tr("Box Prompt"),
            lambda: self.toggleDrawMode(False, createMode="rectangle"),
            'None',
            "objects",
            self.tr("Box Prompt"),
            enabled=True,
        )
        cleanPrompt = action(
            self.tr("Reject"),
            lambda: self.cleanPrompt(),
            'r',
            "objects",
            self.tr("Reject"),
            enabled=True,
        )
        
        self.switchClass = action(
            self.tr("Class On/Off"),
            lambda: self.clickSwitchClass(),
            'none',
            "objects",
            self.tr("Class On/Off"),
            enabled=True,
        )

        editMode = action(
            self.tr("Edit Polygons"),
            self.setEditMode,
            'None',
            "edit",
            self.tr("Move and edit the selected polygons"),
            enabled=False,
        )
        saveAs = action(
            self.tr("&Save As"),
            self.saveFileAs,
            'ALT+s',
            "save-as",
            self.tr("Save labels to a different file"),
            enabled=True,
        )

        undoLastPoint = action(
            self.tr("Undo last point"),
            self.canvas.undoLastPoint,
            'U',
            "undo",
            self.tr("Undo last drawn point"),
            enabled=False,
        )

        hideAll = action(
            self.tr("&Hide\nPolygons"),
            functools.partial(self.togglePolygons, False),
            icon="eye",
            tip=self.tr("Hide all polygons"),
            enabled=False,
        )
        showAll = action(
            self.tr("&Show\nPolygons"),
            functools.partial(self.togglePolygons, True),
            icon="eye",
            tip=self.tr("Show all polygons"),
            enabled=False,
        )

        undo = action(
            self.tr("Undo"),
            self.undoShapeEdit,
            'Ctrl+U',
            "undo",
            self.tr("Undo last add and edit of shape"),
            enabled=False,
        )

        save = action(
            self.tr("&Save"),
            self.saveFile,
            'S',
            "save",
            self.tr("Save labels to file"),
            enabled=False,
        )

        delete = action(
            self.tr("Delete Polygons"),
            self.deleteSelectedShape,
            'd',
            "cancel",
            self.tr("Delete the selected polygons"),
            enabled=False,
        )
        duplicate = action(
            self.tr("Duplicate Polygons"),
            self.duplicateSelectedShape,
            'None',
            "copy",
            self.tr("Create a duplicate of the selected polygons"),
            enabled=False,
        )
        reduce_point = action(
            self.tr("Reduce Points"),
            self.reducePoint,
            'None',
            "copy",
            self.tr("Reduce Points"),
            enabled=True,
        )            
        edit = action(
            self.tr("&Edit Label"),
            self.editLabel,
            'None',
            "edit",
            self.tr("Modify the label of the selected polygon"),
            enabled=False,
        )
        

        self.actions = utils.struct(
            categoryFile=categoryFile,
            imageDirectory=imageDirectory,
            saveDirectory=saveDirectory,
            switchClass=self.switchClass,
            loadSAM=LoadSAM,
            #autoSeg=AutoSeg,
            promptSeg=promptSeg,
            cleanPrompt=cleanPrompt,
            createMode=createMode,
            createPointMode=createPointMode,
            createRectangleMode=createRectangleMode,
            editMode=editMode,
            undoLastPoint=undoLastPoint,
            undo=undo,
            delete=delete,
            edit=edit,
            duplicate=duplicate,
            reduce_point=reduce_point,
            save=save,
            onShapesPresent=(saveAs, hideAll, showAll),
            menu=(
                createMode,
                editMode,
                undoLastPoint,
                undo,
                save,
            )
            )

        # Custom context menu for the canvas widget:
        utils.addActions(self.canvas.menus[0], self.actions.menu)
        utils.addActions(
            self.canvas.menus[1],
            (
                action("&Copy here", self.copyShape),
                action("&Move here", self.moveShape),
            ),
        )

        self.toolbar = self.addToolBar('Tool')
        self.toolbar.addAction(categoryFile)
        self.toolbar.addAction(imageDirectory)
        self.toolbar.addAction(saveDirectory)
        self.toolbar.addAction(self.switchClass)
        self.toolbar.addAction(LoadSAM)
        #self.toolbar.addAction(AutoSeg)
        self.toolbar.addAction(promptSeg)
        self.toolbar.addAction(cleanPrompt)
        self.toolbar.addAction(createMode)
        self.toolbar.addAction(createPointMode)
        self.toolbar.addAction(createRectangleMode)
        self.toolbar.addAction(editMode)
        self.toolbar.addAction(undoLastPoint)
        self.toolbar.addAction(undo)
        self.toolbar.addAction(delete)
        self.toolbar.addAction(edit)
        self.toolbar.addAction(duplicate)
        self.toolbar.addAction(reduce_point)
        self.toolbar.addAction(save)
        self.toolbar.setToolButtonStyle(Qt.ToolButtonTextOnly)

        zoom = QtWidgets.QWidgetAction(self)
        zoom.setDefaultWidget(self.zoomWidget)
        self.zoomWidget.setWhatsThis(
            str(
                self.tr(
                    "Zoom in or out of the image. Also accessible with "
                    "{} from the canvas."
                )
            ).format(
                #utils.fmtShortcut(
                #    "{},{}".format(shortcuts["zoom_in"], shortcuts["zoom_out"])
                #),
                utils.fmtShortcut(self.tr("Ctrl+Wheel")),
            )
        )
        self.zoomWidget.setEnabled(True)

        self.zoomWidget.valueChanged.connect(self.paintCanvas)
        self.canvas.actions = self.actions


    def saveFileAs(self, _value=False):
        assert not self.image.isNull(), "cannot save empty image"
        self._saveFile(self.saveFileDialog())

    def saveFile(self, _value=False):
        # assert not self.image.isNull(), "cannot save empty image"
        # if self.labelFile:
        #     # DL20180323 - overwrite when in directory
        #     self._saveFile(self.labelFile.filename)
        # elif self.output_file:
        #     self._saveFile(self.output_file)
        #     self.close()
        # else:
        #     self._saveFile(self.saveFileDialog())
        #self._saveFile(self.saveFileDialog())
        #print(self.current_output_filename)
        self._saveFile(self.current_output_filename)

    def _saveFile(self, filename):
        if filename and self.saveLabels(filename):
            self.setClean()
    
    def saveLabels(self, filename):
        lf = LabelFile()

        def format_shape(s):
            data = s.other_data.copy()
            data.update(
                dict(
                    label=s.label.encode("utf-8") if PY2 else s.label,
                    points=[(p.x(), p.y()) for p in s.points],
                    group_id=s.group_id,
                    shape_type=s.shape_type,
                    flags=s.flags,
                    #change
                    line_color = None,
                    fill_color = None
                )
            )
            return data
        #change
        def format_shape2(s):
            # labelme-style shape; the group id is appended to the label.
            # (f-strings require Python 3, so the old PY2 branch is dropped.)
            data = dict(
                label=f"{s.label}_{s.group_id}",
                line_color=None,
                fill_color=None,
                points=[(p.x(), p.y()) for p in s.points],
                shape_type=s.shape_type,
                flags={},
            )
            return data
        shapes = [format_shape2(item.shape()) for item in self.labelList]
        labelmedict = dict(
            version = "3.16.7",
            flags = {},
            shapes = shapes,
            lineColor = [0,255,0,128],
            fillColor = [0,255,0,128],
            imagePath = os.path.splitext(os.path.basename(filename))[0]+".jpg",
            imageData = None,
            imageHeight = None,
            imageWidth = None
        ) 
        with open(filename, 'w') as f:
            json.dump(labelmedict, f,indent =4)
        return True

    def setClean(self):
        self.dirty = False
        self.actions.save.setEnabled(False)
        self.actions.createMode.setEnabled(True)

    def saveFileDialog(self):
        caption = self.tr("Choose File")
        filters = self.tr("Label files")
        if self.output_dir:
            dlg = QtWidgets.QFileDialog(
                self, caption, self.output_dir, filters
            )
        else:
            dlg = QtWidgets.QFileDialog(
                self, caption, self.currentPath(), filters
            )
        dlg.setDefaultSuffix(LabelFile.suffix[1:])
        dlg.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
        dlg.setOption(QtWidgets.QFileDialog.DontConfirmOverwrite, False)
        dlg.setOption(QtWidgets.QFileDialog.DontUseNativeDialog, False)
        basename = os.path.splitext(os.path.basename(self.current_img))[0]
        if self.output_dir:
            default_labelfile_name = osp.join(
                self.output_dir, basename + LabelFile.suffix
            )
        else:
            default_labelfile_name = osp.join(
                self.currentPath(), basename + LabelFile.suffix
            )
        filename = dlg.getSaveFileName(
            self,
            self.tr("Choose File"),
            default_labelfile_name,
            self.tr("Label files (*%s)") % LabelFile.suffix,
        )
        if isinstance(filename, tuple):
            filename, _ = filename
        return filename

    def currentPath(self):
        #return osp.dirname(str(self.filename)) if self.filename else "."
        return "."
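    def loadAnno(self, filename):
        """Load annotations stored in labelme format."""
        if os.path.exists(filename):
            with open(filename, 'r') as f:
                label_data = json.load(f)

            flags = label_data["flags"]
            for shape_data in label_data['shapes']:
                points = shape_data['points']
                if not points:
                    # skip point-empty shape
                    continue
                shape_type = shape_data.get('shape_type', None)
                # saveLabels() writes labels as "<label>_<group_id>". Split on the
                # last underscore so labels that contain "_" are preserved, and
                # fall back to no group id when the suffix is missing.
                raw_label = shape_data['label']
                name, sep, suffix = raw_label.rpartition('_')
                if sep and suffix.isdigit():
                    label, group_id = name, int(suffix)
                elif sep and suffix == 'None':
                    label, group_id = name, None
                else:
                    label, group_id = raw_label, None
                shape = Shape(
                    label=label,
                    shape_type=shape_type,
                    group_id=group_id,
                    flags=flags)
                for x, y in points:
                    shape.addPoint(QtCore.QPointF(x, y))
                shape.close()
                self.addLabel(shape)
            self.canvas.loadShapes([item.shape() for item in self.labelList])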


    #change
    def loadAnno2(self, filename):
        
        with open(filename,'r') as f:
            data = json.load(f)
        for shape in data:
            label = shape["label"]
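            try:
                # The label may be stored as an index into the category list;
                # otherwise it is kept as-is.
                ttt = int(label)
                label = self.category_list[ttt]
            except (ValueError, TypeError, IndexError, AttributeError):
                pass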
            points = shape["points"]
            shape_type = shape["shape_type"]
            flags = shape["flags"]
            group_id = shape["group_id"]
            if not points:
                # skip point-empty shape
                continue
            shape = Shape(
                label=label,
                shape_type=shape_type,
                group_id=group_id,
                flags=flags
            )
            for x, y in points:
                shape.addPoint(QtCore.QPointF(x, y))
            shape.close()
            self.addLabel(shape)
        self.canvas.loadShapes([item.shape() for item in self.labelList])

    def clickButtonNext(self):
        if self.current_img_index < self.img_len - 1:
            self.current_img_index += 1
            self.current_img = self.img_list[self.current_img_index]
            self.loadImg()

    def clickButtonLast(self):
        if self.current_img_index > 0:
            self.current_img_index -= 1
            self.current_img = self.img_list[self.current_img_index]
            self.loadImg()


    def choose_proposal1(self):
        if len(self.sam_mask_proposal) > 0:
            self.sam_mask = self.sam_mask_proposal[0]
            self.canvas.setHiding()
            self.canvas.update()

    def choose_proposal2(self):
        if len(self.sam_mask_proposal) > 1:
            self.sam_mask = self.sam_mask_proposal[1]
            self.canvas.setHiding()
            self.canvas.update()
            
    def choose_proposal3(self):
        if len(self.sam_mask_proposal) > 2:
            self.sam_mask = self.sam_mask_proposal[2]
            self.canvas.setHiding()
            self.canvas.update()
            
    def choose_proposal4(self):
        if len(self.sam_mask_proposal) > 3:
            self.sam_mask = self.sam_mask_proposal[3]
            self.canvas.setHiding()
            self.canvas.update()
            
    def loadImg(self):
        pixmap = QPixmap(self.current_img)
        self.canvas.loadPixmap(pixmap)
        self.img_progress_bar.setValue(self.current_img_index)
        
        img_name = os.path.splitext(os.path.basename(self.current_img))[0]
        self.current_output_filename = osp.join(self.current_output_dir, img_name + '.json')
        self.labelList.clear()
        if os.path.isfile(self.current_output_filename):
            self.loadAnno(self.current_output_filename)
        self.image_encoded_flag = False

    def clickFileChoose(self):
        directory = QFileDialog.getExistingDirectory(self, 'Choose target folder','.')
        if directory == '':
            return
        #self.img_list = glob.glob(directory + '/*.{jpg,png,JPG,PNG}')
        self.img_list = glob.glob(directory + '/*.jpg') + glob.glob(directory + '/*.png')
        self.img_list.sort()
        self.img_len = len(self.img_list)
        if self.img_len == 0:
            return
        self.current_img_index = 0
        self.current_img = self.img_list[self.current_img_index]
        self.img_progress_bar.setMinimum(0)
        self.img_progress_bar.setMaximum(self.img_len-1)
        self.loadImg()

    def clickSaveChoose(self):
        directory = QFileDialog.getExistingDirectory(self, 'Choose target folder','.')
        if directory == '':
            return
        else:
            self.current_output_dir = directory
            os.makedirs(self.current_output_dir, exist_ok=True)
            self.loadImg()
            return directory


    def clickSwitchClass(self):
        if self.class_on_flag:
            self.class_on_flag = False
            self.class_on_text.setText('Class Off')
        else:
            self.class_on_flag = True
            self.class_on_text.setText('Class On')


    def clickCategoryChoose(self):
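        filename, _ = QFileDialog.getOpenFileName(self, 'Choose category file','.')
        if not filename:
            # The dialog was cancelled.
            return
        try:
            with open(filename, 'r') as f:
                data = f.readlines()
                # One category per line; blank lines are ignored.
                self.category_list = [i.strip() for i in data if i.strip()]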
                self.category_list.sort()
                self.labelDialog = LabelDialog(
                    parent=self,
                    labels=self.category_list,
                    sort_labels=False,
                    show_text_field=True,
                    completion='contains',
                    fit_to_content={'column': True, 'row': False},
                )
        except Exception as e:
            pass

    def clickLoadSAM(self):
        download_model(self.model_type)
        self.sam = sam_model_registry[self.model_type](checkpoint='{}.pth'.format(self.model_type))
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.sam.to(device=self.device)
        self.predictor = SamPredictor(self.sam)
        self.actions.loadSAM.setEnabled(False)
        #self.actions.autoSeg.setEnabled(True)
        self.actions.promptSeg.setEnabled(True)
    
    def clickAutoSeg(self):
        pass
    
    def getMaxId(self):
        max_id = -1
        for label in self.labelList:
            if label.shape().group_id is not None:
                max_id = max(max_id, int(label.shape().group_id))
        return max_id
        
    def show_proposals(self, masks=None, flag=1):
        if flag != 1:
            img = cv2.imread(self.current_img)
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            for msk_idx in range(masks.shape[0]):
                tmp_mask = masks[msk_idx]
                tmp_vis = img.copy()
                tmp_vis[tmp_mask > 0] = 0.5 * tmp_vis[tmp_mask > 0] + 0.5 * np.array([30,30,220])
                tmp_vis = cv2.resize(tmp_vis,(int(0.17 * global_w),int(0.14 * global_h)))
                tmp_vis = tmp_vis.astype(np.uint8)
                pixmap = QPixmap.fromImage(QImage(tmp_vis, tmp_vis.shape[1], tmp_vis.shape[0], tmp_vis.shape[1] * 3 , QImage.Format_RGB888))
                #self.button_proposal_list[msk_idx].setPixmap(pixmap)
                self.button_proposal_list[msk_idx].setIcon(QIcon(pixmap))
                self.button_proposal_list[msk_idx].setIconSize(QSize(tmp_vis.shape[1], tmp_vis.shape[0]))
                self.button_proposal_list[msk_idx].setShortcut(str(msk_idx+1))
        else:
            for idx, button_proposal in enumerate(self.button_proposal_list):
                button_proposal.setText('proposal{}'.format(idx))
                button_proposal.setIconSize(QSize(0,0))
                self.button_proposal_list[idx].setShortcut(str(idx+1))

    def transform_input(self, image, box=None, points=None):
        if self.keep_input_size == True:
            return image, box, points
        else:
            h,w = image.shape[:2]
            scale_ratio = self.max_size / max(h,w)
            image = cv2.resize(image, (int(w*scale_ratio), int(h*scale_ratio)))
            if box is not None:
                box = box * scale_ratio
            if points is not None:
                points = points * scale_ratio
            return image, box, points
    
    def transform_output(self, masks, size):
        if self.keep_input_size == True:
            return masks
        else:
            h,w = size
            N = masks.shape[0]
            new_masks = np.zeros((N,h,w), dtype=np.uint8)
            for idx in range(N):
                # Nearest-neighbour interpolation keeps the mask strictly binary.
                new_masks[idx] = cv2.resize(masks[idx], (w,h), interpolation=cv2.INTER_NEAREST)
            return new_masks

    def clickManualSegBBox(self):
        """Run SAM with the box prompt drawn on the canvas and build polygon proposals."""
        Box = self.canvas.currentBox
        if self.predictor is None or self.current_img == '' or Box is None:
            return
        img = cv2.imread(self.current_img)[:,:,::-1]
        rh, rw = img.shape[:2]
        input_box = np.array([Box[0].x(), Box[0].y(), Box[1].x(), Box[1].y()])
        img, input_box, _ = self.transform_input(img, box=input_box)
        if self.image_encoded_flag == False:
            self.predictor.set_image(img)
            self.image_encoded_flag = True
        masks, iou_prediction, _ = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=True,
        )
        masks = self.transform_output(masks.astype(np.uint8), (rh,rw))

        target_idx = np.argmax(iou_prediction)
        self.show_proposals(masks, 0)
        self.sam_mask_proposal = []
        for msk_idx in range(masks.shape[0]):
            mask = masks[msk_idx].astype(np.uint8)
            #CHANGE
            #points_list = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[0]
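            contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) == 0:
                # Empty mask: keep the proposal slot so indices still match the buttons.
                if msk_idx == target_idx:
                    self.sam_mask = []
                self.sam_mask_proposal.append([])
                continue
            # Several contours may be detected; keep only the largest one (con_amax).
            contours = sorted(contours, key=cv2.contourArea, reverse=True)
            con_amax=contours[0]
            con_amax=self.lesscnt(con_amax)
            points_list =  [con_amax]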
            shape_type = 'polygon'
            tmp_sam_mask = []
            for points in points_list:
                area = cv2.contourArea(points)
                if area < 100 and len(points_list) > 1:
                    continue
                pointsx = points[:,0,0]
                pointsy = points[:,0,1]

                shape = Shape(
                    label='stone',
                    shape_type=shape_type,
                    group_id=self.getMaxId() + 1,
                )
                for point_index in range(pointsx.shape[0]):
                    shape.addPoint(QtCore.QPointF(pointsx[point_index], pointsy[point_index]))
                shape.close()
                #self.addLabel(shape)
                tmp_sam_mask.append(shape)
            if msk_idx == target_idx:
                self.sam_mask = tmp_sam_mask
            self.sam_mask_proposal.append(tmp_sam_mask)


    def clickManualSegBox(self):
        """Run SAM with the positive/negative point prompts and build polygon proposals."""
        ClickPos = self.canvas.currentPos
        ClickNeg = self.canvas.currentNeg
        if self.predictor is None or self.current_img == '' or (ClickPos is None and ClickNeg is None):
            return
        img = cv2.imread(self.current_img)[:,:,::-1]
        rh, rw = img.shape[:2]

        input_clicks = []
        input_types = []
        if ClickPos is not None:
            for pos in ClickPos:
                input_clicks.append([int(pos.x()), int(pos.y())])
                input_types.append(1)  # 1 = foreground point

        if ClickNeg is not None:
            for neg in ClickNeg:
                input_clicks.append([int(neg.x()), int(neg.y())])
                input_types.append(0)  # 0 = background point
        if len(input_clicks) == 0:
            input_clicks = None
            input_types = None
        else:
            input_clicks = np.array(input_clicks)
            input_types = np.array(input_types)

        img, _, input_clicks = self.transform_input(img, points=input_clicks)

        if self.image_encoded_flag == False:
            self.predictor.set_image(img)
            self.image_encoded_flag = True
        masks, iou_prediction, _ = self.predictor.predict(
            point_coords=input_clicks,
            point_labels=input_types,
            multimask_output=True,
        )
        masks = self.transform_output(masks.astype(np.uint8), (rh,rw))
        
        target_idx = np.argmax(iou_prediction)
        self.show_proposals(masks,0)
        self.sam_mask_proposal = []
        
        for msk_idx in range(masks.shape[0]):
            mask = masks[msk_idx].astype(np.uint8)
            
            #points_list = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[0]
            #max_contour = max(points_list, key=cv2.contourArea)
            #points_list=[self.lesscnt(max_contour)]
            #
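            contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) == 0:
                # Empty mask: keep the proposal slot so indices still match the buttons.
                if msk_idx == target_idx:
                    self.sam_mask = []
                self.sam_mask_proposal.append([])
                continue
            # Several contours may be detected; keep only the largest one (con_amax).
            contours = sorted(contours, key=cv2.contourArea, reverse=True)
            con_amax=contours[0]
            con_amax=self.lesscnt(con_amax)
            points_list =  [con_amax]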
            
            #points_list=contours
            shape_type = 'polygon'
            tmp_sam_mask = []
            for points in points_list:
                area = cv2.contourArea(points)
                if area < 100 and len(points_list) > 1:
                    continue
                pointsx = points[:,0,0]
                pointsy = points[:,0,1]

                shape = Shape(
                    label='stone',
                    shape_type=shape_type,
                    group_id=self.getMaxId() + 1,
                )
                for point_index in range(pointsx.shape[0]):
                    shape.addPoint(QtCore.QPointF(pointsx[point_index], pointsy[point_index]))
                shape.close()
                #self.addLabel(shape)
                tmp_sam_mask.append(shape)
            if msk_idx == target_idx:
                self.sam_mask = tmp_sam_mask
            self.sam_mask_proposal.append(tmp_sam_mask)
            
    def addSamMask(self):
        if len(self.sam_mask) > 0:
            #change
            label = 'stone'
            group_id = self.getMaxId() + 1
            if self.class_on_flag:
                xx = self.labelDialog.popUp(
                    text=label,
                    flags={},
                    group_id=group_id,
                )
                if len(xx) == 4:
                    label, _, group_id,_ = xx
                else:
                    label, _, group_id = xx
            if label is None:
                label = 'Object'
            if not isinstance(group_id, int):
                group_id=self.getMaxId() + 1
            for sam_mask in self.sam_mask:
                sam_mask.label = label
                sam_mask.group_id = group_id
                self.addLabel(sam_mask)
        self.canvas.currentBox = None
        self.canvas.currentPos = None
        self.canvas.currentNeg = None
        self.sam_mask = []
        self.sam_mask_proposal = []
        self.show_proposals()
        self.canvas.loadShapes([item.shape() for item in self.labelList])
        self.actions.save.setEnabled(True)
        self.actions.editMode.setEnabled(True)



    def cleanPrompt(self):
        self.canvas.currentBox = None
        self.canvas.currentPos = None
        self.canvas.currentNeg = None
        self.canvas.current = None
        self.sam_mask = []
        self.sam_mask_proposal = []
        self.show_proposals()
        self.canvas.setHiding()
        self.canvas.update()
        self.actions.editMode.setEnabled(True)



    def zoomRequest(self, delta, pos):
        canvas_width_old = self.canvas.width()
        units = 1.1
        if delta < 0:
            units = 0.9
        self.addZoom(units)

        canvas_width_new = self.canvas.width()
        if canvas_width_old != canvas_width_new:
            canvas_scale_factor = canvas_width_new / canvas_width_old

            x_shift = round(pos.x() * canvas_scale_factor) - pos.x()
            y_shift = round(pos.y() * canvas_scale_factor) - pos.y()

            self.setScroll(
                Qt.Horizontal,
                self.scrollBars[Qt.Horizontal].value() + x_shift,
            )
            self.setScroll(
                Qt.Vertical,
                self.scrollBars[Qt.Vertical].value() + y_shift,
            )

    def scrollRequest(self, delta, orientation):
        units = -delta * 0.1  # natural scroll
        bar = self.scrollBars[orientation]
        value = bar.value() + bar.singleStep() * units
        self.setScroll(orientation, value)

    def newShape(self):
        """Pop-up and give focus to the label editor.

        position MUST be in global coordinates.
        """
        items = self.uniqLabelList.selectedItems()
        text = None
        if items:
            text = items[0].data(Qt.UserRole)
        flags = {}
        group_id = None
        if not text:
            previous_text = self.labelDialog.edit.text()
            xx = self.labelDialog.popUp(text)
            if len(xx) == 4:
                text, flags, group_id, _ = xx
            else:
                text, flags, group_id = xx
            if not text:
                self.labelDialog.edit.setText(previous_text)

        if text and not self.validateLabel(text):
            self.errorMessage(
                self.tr("Invalid label"),
                self.tr("Invalid label '{}' with validation type '{}'").format(
                    text, self._config["validate_label"]
                ),
            )
            text = ""
        if text:
            self.labelList.clearSelection()
            shape = self.canvas.setLastLabel(text, flags)
            shape.group_id = group_id
            self.addLabel(shape)
            self.actions.editMode.setEnabled(True)
            self.actions.undoLastPoint.setEnabled(False)
            self.actions.undo.setEnabled(True)
            self.setDirty()
        else:
            self.canvas.undoLastLine()
            self.canvas.shapesBackups.pop()

    def setDirty(self):
        # Even if we autosave the file, we keep the ability to undo
        self.actions.undo.setEnabled(self.canvas.isShapeRestorable)

        # if self._config["auto_save"] or self.actions.saveAuto.isChecked():
        #     label_file = osp.splitext(self.imagePath)[0] + ".json"
        #     if self.output_dir:
        #         label_file_without_path = osp.basename(label_file)
        #         label_file = osp.join(self.output_dir, label_file_without_path)
        #     self.saveLabels(label_file)
        #     return
        # self.dirty = True
        self.actions.save.setEnabled(True)
        # title = __appname__
        # if self.filename is not None:
        #     title = "{} - {}*".format(title, self.filename)
        # self.setWindowTitle(title)

    # React to canvas signals.
    def shapeSelectionChanged(self, selected_shapes):
        self._noSelectionSlot = True
        for shape in self.canvas.selectedShapes:
            shape.selected = False
        self.labelList.clearSelection()
        self.canvas.selectedShapes = selected_shapes
        for shape in self.canvas.selectedShapes:
            shape.selected = True
            item = self.labelList.findItemByShape(shape)
            self.labelList.selectItem(item)
            self.labelList.scrollToItem(item)
        self._noSelectionSlot = False
        n_selected = len(selected_shapes)
        self.actions.delete.setEnabled(n_selected)
        self.actions.duplicate.setEnabled(n_selected)
        self.actions.edit.setEnabled(n_selected == 1)

    def toggleDrawingSensitive(self, drawing=True):
        """Toggle drawing sensitive.

        In the middle of drawing, toggling between modes should be disabled.
        """
        self.actions.editMode.setEnabled(not drawing)
        # self.actions.undoLastPoint.setEnabled(drawing)
        # self.actions.undo.setEnabled(not drawing)
        # self.actions.delete.setEnabled(not drawing)
    def setScroll(self, orientation, value):
        self.scrollBars[orientation].setValue(int(value))
        self.scroll_values[orientation][self.current_img] = value

    def toolbar(self, title, actions=None):
        toolbar = self.addToolBar("%sToolBar" % title)
        # toolbar.setOrientation(Qt.Vertical)
        if actions:
            utils.addActions(toolbar, actions)
        return toolbar

    def setEditMode(self):
        self.toggleDrawMode(True)

    def toggleDrawMode(self, edit=True, createMode="polygon"):
        self.canvas.setEditing(edit)
        self.canvas.createMode = createMode
        if edit:
            self.actions.createMode.setEnabled(True)
            self.actions.createPointMode.setEnabled(True)
            self.actions.createRectangleMode.setEnabled(True)

        else:
            if createMode == "polygon":
                self.actions.createPointMode.setEnabled(True)
                self.actions.createMode.setEnabled(False)
                self.actions.createRectangleMode.setEnabled(True)

            elif createMode == "point":
                self.actions.createMode.setEnabled(True)
                self.actions.createPointMode.setEnabled(False)
                self.actions.createRectangleMode.setEnabled(True)
            elif createMode == "rectangle":
                self.actions.createMode.setEnabled(True)
                self.actions.createPointMode.setEnabled(True)
                self.actions.createRectangleMode.setEnabled(False)
            else:
                raise ValueError("Unsupported createMode: %s" % createMode)
        self.actions.editMode.setEnabled(not edit)

    def validateLabel(self, label):
        return True

    def labelSelectionChanged(self):
        if self._noSelectionSlot:
            return
        if self.canvas.editing():
            selected_shapes = []
            for item in self.labelList.selectedItems():
                selected_shapes.append(item.shape())
            if selected_shapes:
                self.canvas.selectShapes(selected_shapes)
            else:
                self.canvas.deSelectShape()

    def iou(self, target_mask, mask_list):
        target_mask = target_mask.reshape(1,-1)
        mask_list = mask_list.reshape(mask_list.shape[0], -1)
        i = (target_mask * mask_list)
        u = target_mask + mask_list - i
        return i.sum(1)/u.sum(1)


    def polygon2mask(self,polygon, size):
        mask = np.zeros((size)) # h,w
        contours = np.array(polygon)
        mask = cv2.fillPoly(mask, [contours.astype(np.int32)],1)
        return mask.astype(np.uint8)

    def mask2polygon(self, mask): # convert a mask to a polygon
        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        #contours=self.lesscnt(contours[0])
        contours = np.array(contours[0])
        #contours = np.array(contours)
        return contours
    def lesscnt(self,cnt):
        '''
        Simplify the points of a contour.
        ----------
        Input: cnt, a single contour such as contours[0]. The returned polyFit
        has the same format as cnt.'''

        if len(cnt) >3:
            polyFit = cv2.approxPolyDP(cnt,2, True)
            return polyFit
        else:
            return cnt
    def editLabel(self, item=None):
        if item and not isinstance(item, LabelListWidgetItem):
            raise TypeError("item must be LabelListWidgetItem type")

        if not self.canvas.editing():
            return
        if not item:
            item = self.currentItem()
        if item is None:
            return
        shape = item.shape()
        if shape is None:
            return
        xx = self.labelDialog.popUp(
            text=shape.label,
            flags=shape.flags,
            group_id=shape.group_id,
        )
        if len(xx) == 4:
            text, flags, group_id,_ = xx
        else:
            text, flags, group_id = xx
        if text is None:
            return
        if not self.validateLabel(text):
            self.errorMessage(
                self.tr("Invalid label"),
                self.tr("Invalid label '{}' with validation type '{}'").format(
                    text, self._config["validate_label"]
                ),
            )
            return
        shape.label = text
        shape.flags = flags
        shape.group_id = group_id

        self._update_shape_color(shape)
        if shape.group_id is None:
            item.setText(
                '{} <font color="#{:02x}{:02x}{:02x}">●</font>'.format(
                    html.escape(shape.label), *shape.fill_color.getRgb()[:3]
                )
            )
        else:
            item.setText("({}) {}".format(shape.group_id, shape.label))
        self.setDirty()
        if self.uniqLabelList.findItemByLabel(shape.label) is None:
            item = self.uniqLabelList.createItemFromLabel(shape.label)
            self.uniqLabelList.addItem(item)
            # rgb = self._get_rgb_by_label(shape.label)
            rgb = self._get_rgb_by_label(shape.group_id)
            self.uniqLabelList.setItemLabel(item, shape.label, rgb)

    def labelItemChanged(self, item):
        shape = item.shape()
        self.canvas.setShapeVisible(shape, item.checkState() == Qt.Checked)

    def labelOrderChanged(self):
        self.setDirty()
        self.canvas.loadShapes([item.shape() for item in self.labelList])

    def addLabel(self, shape):
        if shape.group_id is None:
            text = shape.label
        else:
            text = "({}) {}".format(shape.group_id, shape.label)
        label_list_item = LabelListWidgetItem(text, shape)
        self.labelList.addItem(label_list_item)
        if self.uniqLabelList.findItemByLabel(shape.label) is None:
            item = self.uniqLabelList.createItemFromLabel(shape.label)
            self.uniqLabelList.addItem(item)
            # rgb = self._get_rgb_by_label(shape.label)
            rgb = self._get_rgb_by_label(shape.group_id)
            self.uniqLabelList.setItemLabel(item, shape.label, rgb)
        self.labelDialog.addLabelHistory(shape.label)
        for action in self.actions.onShapesPresent:
            action.setEnabled(True)

        self._update_shape_color(shape)
        label_list_item.setText(
            '{} <font color="#{:02x}{:02x}{:02x}">●</font>'.format(
                html.escape(text), *shape.fill_color.getRgb()[:3]
            )
        )
    def _get_rgb_by_label(self, label):
        label = str(label)
        item = self.uniqLabelList.findItemByLabel(label)
        if item is None:
            item = self.uniqLabelList.createItemFromLabel(label)
            self.uniqLabelList.addItem(item)
            rgb = self._get_rgb_by_label(label)
            self.uniqLabelList.setItemLabel(item, label, rgb)
        label_id = self.uniqLabelList.indexFromItem(item).row() + 1
        label_id += 0
        return LABEL_COLORMAP[label_id % len(LABEL_COLORMAP)]

    def togglePolygons(self, value):
        for item in self.labelList:
            item.setCheckState(Qt.Checked if value else Qt.Unchecked)

    def _update_shape_color(self, shape):
        # r, g, b = self._get_rgb_by_label(shape.label)
        r, g, b = self._get_rgb_by_label(shape.group_id)
        shape.line_color = QtGui.QColor(r, g, b)
        shape.vertex_fill_color = QtGui.QColor(r, g, b)
        shape.hvertex_fill_color = QtGui.QColor(255, 255, 255)
        shape.fill_color = QtGui.QColor(r, g, b, 128)
        shape.select_line_color = QtGui.QColor(255, 255, 255)
        shape.select_fill_color = QtGui.QColor(r, g, b, 155)

    def undoShapeEdit(self):
        self.canvas.restoreShape()
        self.labelList.clear()
        self.loadShapes(self.canvas.shapes)
        self.actions.undo.setEnabled(self.canvas.isShapeRestorable)

    def loadShapes(self, shapes, replace=True):
        self._noSelectionSlot = True
        for shape in shapes:
            self.addLabel(shape)
        self.labelList.clearSelection()
        self._noSelectionSlot = False
        self.canvas.loadShapes(shapes, replace=replace)


    def moveShape(self):
        self.canvas.endMove(copy=False)
        self.setDirty()

    def copyShape(self):
        self.canvas.endMove(copy=True)
        for shape in self.canvas.selectedShapes:
            self.addLabel(shape)
        self.labelList.clearSelection()
        self.setDirty()
    def deleteSelectedShape(self):
        #yes, no = QtWidgets.QMessageBox.Yes, QtWidgets.QMessageBox.No
        #msg = self.tr(
        #    "You are about to permanently delete {} polygons, "
        #    "proceed anyway?"
        #).format(len(self.canvas.selectedShapes))
        #if yes == QtWidgets.QMessageBox.warning(
        #    self, self.tr("Attention"), msg, yes | no, yes
        #):
        self.remLabels(self.canvas.deleteSelected())
        self.setDirty()
        if self.noShapes():
            for action in self.actions.onShapesPresent:
                action.setEnabled(False)
    def duplicateSelectedShape(self):
        added_shapes = self.canvas.duplicateSelectedShapes()
        self.labelList.clearSelection()
        for shape in added_shapes:
            self.addLabel(shape)
        self.setDirty()

    def reducePoint(self):
        def format_shape(s):
            data = s.other_data.copy()
            data.update(
                dict(
                    label=s.label.encode("utf-8") if PY2 else s.label,
                    points=[(p.x(), p.y()) for p in s.points],
                    group_id=s.group_id,
                    shape_type=s.shape_type,
                    flags=s.flags,
                    #change
                    line_color = None,
                    fill_color = None
                )
            )
            return data
        def format_shape2(s):
            data=dict(
                group_id=s.group_id,
                label =(s.label.encode("utf-8")+f"_{s.group_id}") if PY2 else (s.label+f"_{s.group_id}"),
                line_color = None,
                fill_color = None,
                points=[(p.x(), p.y()) for p in s.points],
                shape_type=s.shape_type,
                flags={}
            )
            return data
        shapes = [format_shape(item.shape()) for item in self.labelList.selectedItems()]
        rm_shapes = [item.shape() for item in self.labelList.selectedItems()]
        self.remLabels(rm_shapes)
        for shape in shapes:
            points = shape['points']
            min_dis = self.get_min_dis(points)
            points_new = [points[0]]
            for i in range(1,len(points)):
                d = math.sqrt((points[i][0] - points_new[-1][0]) ** 2 + (points[i][1] - points_new[-1][1]) ** 2)
                if d > (min_dis * 1.5):
                    points_new.append(points[i])
            shape['points'] = points_new
        #self.labelList.clear()
        for tmp_shape in shapes:
            shape = Shape(
                label=tmp_shape['label'],
                shape_type=tmp_shape['shape_type'],
                group_id=tmp_shape['group_id'],
            )
            for point_index in range(len(tmp_shape['points'])):
                shape.addPoint(QtCore.QPointF(tmp_shape['points'][point_index][0], tmp_shape['points'][point_index][1]))
            shape.close()
            self.addLabel(shape)
            tmp_item = self.labelList.findItemByShape(shape)
            self.labelList.selectItem(tmp_item)
            self.labelList.scrollToItem(tmp_item)
        self.canvas.loadShapes([item.shape() for item in self.labelList])
        self.actions.save.setEnabled(True)

    def get_min_dis(self, points):
        # Smallest distance between consecutive points; 10000 is returned
        # unchanged when fewer than two points are given.
        min_dis = 10000
        for prev, cur in zip(points, points[1:]):
            d = math.hypot(cur[0] - prev[0], cur[1] - prev[1])
            min_dis = min(min_dis, d)
        return min_dis



    def pasteSelectedShape(self):
        self.loadShapes(self._copied_shapes, replace=False)
        self.setDirty()

    def copySelectedShape(self):
        self._copied_shapes = [s.copy() for s in self.canvas.selectedShapes]
        self.actions.paste.setEnabled(len(self._copied_shapes) > 0)

    def currentItem(self):
        items = self.labelList.selectedItems()
        if items:
            return items[0]
        return None

    def remLabels(self, shapes):
        for shape in shapes:
            item = self.labelList.findItemByShape(shape)
            self.labelList.removeItem(item)


    def noShapes(self):
        return not len(self.labelList)

    def addZoom(self, increment=1.1):
        zoom_value = self.zoomWidget.value() * increment
        if increment > 1:
            zoom_value = math.ceil(zoom_value)
        else:
            zoom_value = math.floor(zoom_value)
        self.setZoom(zoom_value)

    def setZoom(self, value):
        self.zoomMode = self.MANUAL_ZOOM
        self.zoomWidget.setValue(value)
        self.zoom_values[self.current_img] = (self.zoomMode, value)

    def paintCanvas(self):
        self.canvas.scale = 0.01 * self.zoomWidget.value()
        self.canvas.adjustSize()
        self.canvas.update()


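def str2bool(value):
    # argparse's type=bool calls bool() on the raw string, so "False" would
    # parse as True; accept the common spellings explicitly instead.
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("true", "1", "yes", "y")


def get_parser():
    parser = argparse.ArgumentParser(description="pixel annotator by GroundedSAM")
    parser.add_argument(
        "--app_resolution",
        default='1440,2560',
        #default='800,1200',
    )
    parser.add_argument(
        "--model_type",
        default='vit_b',
    )
    parser.add_argument(
        "--keep_input_size",
        type=str2bool,
        default=True,
    )
    parser.add_argument(
        "--max_size",
        type=int,  # without a type, a value given on the command line stays a string
        default=1024,
    )
    return parser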

if __name__ == '__main__':
    parser = get_parser()
    # Parse the command line once and reuse the result.
    args = parser.parse_args()
    global_h, global_w = [int(i) for i in args.app_resolution.split(',')]
    model_type = args.model_type
    keep_input_size = args.keep_input_size
    max_size = args.max_size
    app = QApplication(sys.argv)
    main = MainWindow(global_h=global_h, global_w=global_w, model_type=model_type, keep_input_size=keep_input_size, max_size=max_size)
    main.show()
    sys.exit(app.exec_())
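
The `reducePoint` and `get_min_dis` methods in the listing above thin a polygon by dropping any point that lies within 1.5 times the smallest consecutive spacing of the last kept point. A standalone sketch of that idea, outside the Qt classes, might look like the following. The function names `min_consecutive_distance` and `reduce_points` are hypothetical, and returning an empty list for empty input is an assumption added here.

```python
import math


def min_consecutive_distance(points, sentinel=10000):
    """Smallest distance between neighbouring points, or `sentinel` if fewer than two."""
    best = sentinel
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        best = min(best, math.hypot(x1 - x0, y1 - y0))
    return best


def reduce_points(points, factor=1.5):
    """Keep a point only if it is farther than factor * min spacing from the last kept point."""
    if not points:
        return []  # assumption: empty input yields empty output
    threshold = min_consecutive_distance(points) * factor
    kept = [points[0]]
    for x, y in points[1:]:
        if math.hypot(x - kept[-1][0], y - kept[-1][1]) > threshold:
            kept.append((x, y))
    return kept


polygon = [(0, 0), (1, 0), (2, 0), (3, 0), (10, 0)]
print(reduce_points(polygon))
```

Because the threshold is derived from the tightest pair of neighbours, a single very close pair makes the reduction mild for the whole polygon, which may or may not be the desired behaviour.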

Box prompt not working / not responding on GPU

I'm not sure whether anyone else has run into this: I click the top-left point and then try to click the bottom-right corner to draw a rectangle (box), and the program stops responding.

This is running on a Windows 10 laptop with a 3070 GPU.

Cannot load annotation JSON file from segment-anything!!

I successfully ran amg.py from the segment-anything repo and got a JSON output file.
I found this annotation tool through an issue in the segment-anything repo and set it up successfully.
Choosing the image folder works, but when I choose the save folder the program crashes as shown below. (I put the JSON file from the segment-anything repo, which is in COCO RLE format, in the save folder.)
The output JSON file has no 'shape' key, so it seems the JSON file from the segment-anything repo can't be used in this tool.

image

If I annotate images with this tool, can I train successfully with the output file format?
Or can it not be used for training with the segment-anything repo? Please help!!

How can I use this tool together with the segment-anything repo successfully?

Errors often occur during annotation

Traceback (most recent call last):
  File "F:\Project\Segment_Anything\segment-anything-annotator\canvas.py", line 514, in mousePressEvent
    self.app.clickManualSegBox()
  File "F:\Project\Segment_Anything\segment-anything-annotator\annnotator.py", line 862, in clickManualSegBox
    group_id=self.getMaxId() + 1,
  File "F:\Project\Segment_Anything\segment-anything-annotator\annnotator.py", line 704, in getMaxId
    max_id = max(max_id, label.shape().group_id)
TypeError: '>' not supported between instances of 'NoneType' and 'int'
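
The traceback suggests that a shape whose `group_id` is `None` reached the `max()` call in `getMaxId`. One possible workaround is to skip such entries. The sketch below works on a plain list of ids rather than on the label list items; the name `max_group_id` and the starting value of 0 are assumptions, not the repository's code.

```python
def max_group_id(group_ids, start=0):
    """Largest integer id in `group_ids`, ignoring entries that are None."""
    max_id = start  # assumption: ids are non-negative and counting starts at 0
    for gid in group_ids:
        if gid is None:
            continue  # shapes without a group id do not take part in the comparison
        max_id = max(max_id, int(gid))
    return max_id


print(max_group_id([1, None, 4, 2]))
```

In the application, the list would presumably be built from `label.shape().group_id` for each item in the label list.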

error

I have already installed segment_anything, but I get the error: no module named 'segment_anything'

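Feature improvement suggestions

I tried it out. First of all, kudos to the author!
I have been using labelme until now, so here are a few commonly used features I would like to request.

  1. JSON format: if the output matched the original official format, the tool would likely attract plenty of users; otherwise the extra downstream data processing will put many people off.
    {
    "version": "5.0.0",
    "flags": {},
    "shapes": [
    {
    "label": "",
    "points": [ ],
    "group_id": ,
    "shape_type": "",
    "flags": {}
    }
    ],
    "imagePath": "",
    "imageData": ,
    "imageHeight": ,
    "imageWidth":
    }

  2. Some keyboard shortcuts:
    Ctrl+Z to undo the last operation;
    switch to the previous/next image (the official shortcuts are A/D);
    automatically save the JSON when switching images;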
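
The labelme-style layout requested in item 1 can be sketched as a small builder that fills in the listed keys. This is only an illustration: `to_labelme_record` is a hypothetical helper, the input shapes are assumed to be plain dicts, and writing `imageData` as null and defaulting `shape_type` to "polygon" are assumptions.

```python
import json


def to_labelme_record(image_path, height, width, shapes, version="5.0.0"):
    """Build a dict with the keys listed in the suggested labelme layout."""
    return {
        "version": version,
        "flags": {},
        "shapes": [
            {
                "label": s["label"],
                "points": [[float(x), float(y)] for x, y in s["points"]],
                "group_id": s.get("group_id"),
                "shape_type": s.get("shape_type", "polygon"),  # assumed default
                "flags": {},
            }
            for s in shapes
        ],
        "imagePath": image_path,
        "imageData": None,  # assumption: image bytes are not embedded
        "imageHeight": height,
        "imageWidth": width,
    }


record = to_labelme_record(
    "example.jpg", 480, 640,
    [{"label": "cat", "points": [(10, 20), (30, 40), (50, 20)], "group_id": 1}],
)
print(json.dumps(record, indent=2))
```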

Sudden exit with a TypeError from drawText on M1 Mac

Hello,
I have followed the instructions in the demo image several times. Each time, after SAM loads successfully, any further step ('Point Prompt', 'Box Prompt' or 'Manual Polygons') exits the program as soon as I click on the picture. The error message is shown below:

python annnotator.py --app_resolution 1000,1600 --model_type vit_b --keep_input_size True --max_size 720
2023-04-18 11:57:16.074 python[36573:3837983] +[CATransaction synchronize] called within transaction
2023-04-18 11:57:17.887 python[36573:3837983] +[CATransaction synchronize] called within transaction
2023-04-18 11:57:23.813 python[36573:3837983] +[CATransaction synchronize] called within transaction
2023-04-18 11:57:38.869 python[36573:3837983] +[CATransaction synchronize] called within transaction
2023-04-18 11:58:02.440 python[36573:3837983] +[CATransaction synchronize] called within transaction
vit_b model already exists as 'vit_b.pth'. Skipping download.
Traceback (most recent call last):
  File "/Users/peterhuang/Documents/test_code/Python/segment-anything-annotator-master/canvas.py", line 921, in paintEvent
    self.current.paint(p)
  File "/Users/peterhuang/Documents/test_code/Python/segment-anything-annotator-master/shape.py", line 213, in paint
    painter.drawText(cx, cy, self.label)
TypeError: arguments did not match any overloaded call:
  drawText(self, p: Union[QPointF, QPoint], s: str): argument 1 has unexpected type 'float'
  drawText(self, rectangle: QRectF, flags: int, text: str): argument 1 has unexpected type 'float'
  drawText(self, rectangle: QRect, flags: int, text: str): argument 1 has unexpected type 'float'
  drawText(self, rectangle: QRectF, text: str, option: QTextOption = QTextOption()): argument 1 has unexpected type 'float'
  drawText(self, p: QPoint, s: str): argument 1 has unexpected type 'float'
  drawText(self, x: int, y: int, width: int, height: int, flags: int, text: str): argument 1 has unexpected type 'float'
  drawText(self, x: int, y: int, s: str): argument 1 has unexpected type 'float'
zsh: abort python annnotator.py --app_resolution 1000,1600 --model_type vit_b True 720

My device is an M1-chip MacBook Air and the Python version is 3.10.10. Do you have any idea how to solve this problem? Thank you so much.
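
Every overload in the traceback rejects a `float` as the first argument, and the `drawText(x: int, y: int, s: str)` overload listed there expects integer coordinates. One possible workaround is to round the label position before the call. The helper name `to_int_xy` is hypothetical, and PyQt is deliberately not imported so the sketch runs on its own.

```python
def to_int_xy(cx, cy):
    """Round float coordinates to the ints expected by drawText(x, y, s)."""
    return int(round(cx)), int(round(cy))


x, y = to_int_xy(103.6, 57.2)
print(x, y)
# With a QPainter at hand, the call in shape.py would then look roughly like:
# painter.drawText(x, y, self.label)
```

Passing a `QPointF(cx, cy)` instead would match the first overload in the list and keep sub-pixel positions; which variant is preferable depends on the rest of the drawing code.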

Error when starting the annotation platform after installing from requirements.txt

Got keys from plugin meta data ("xcb")
QFactoryLoader::QFactoryLoader() checking directory path "/home/xxx/.conda/envs/sam/bin/platforms" ...
loaded library "/home/xxx/.conda/envs/sam/lib/python3.8/site-packages/cv2/qt/plugins/platforms/libqxcb.so"
QObject::moveToThread: Current thread (0x55c4efa86610) is not the object's thread (0x55c4f292d870).
Cannot move to target thread (0x55c4efa86610)

Buggy visualization of high-res images

The annotator app doesn't load images correctly when their resolution is higher than the app resolution: it shows them cropped and rotated. It would be great if the app resized the image to the app resolution and, once the annotations are made at that resolution, converted them back to the original image resolution.
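
The resize-and-map-back idea described above can be sketched with a single uniform scale factor. The function names below are hypothetical and nothing here is taken from the annotator's code; it only illustrates the coordinate arithmetic.

```python
def fit_scale(img_w, img_h, app_w, app_h):
    """Uniform scale (never above 1) that fits the image inside the app area."""
    return min(app_w / img_w, app_h / img_h, 1.0)


def to_display(point, scale):
    """Map a point from original image coordinates to display coordinates."""
    x, y = point
    return (x * scale, y * scale)


def to_original(point, scale):
    """Map a point from display coordinates back to original image coordinates."""
    x, y = point
    return (x / scale, y / scale)


scale = fit_scale(4000, 2000, 1600, 1000)  # 4000x2000 image in a 1600x1000 window
shown = to_display((1200, 900), scale)
restored = to_original(shown, scale)
print(scale, shown, restored)
```

Using one factor for both axes keeps the aspect ratio, so polygons drawn on the downscaled view map back without distortion.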

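Feature improvement suggestions

For a project, I have already deployed the tool and annotated more than a hundred kinds of images. It works great, and I think a few more features could be added:

  1. Compared with the official demo, negative points could also be accepted as input, which would give better annotation results. SAM's code already supports positive and negative point prompts, so I think a mode could be added (Negative/Positive Point Prompt).
  2. When editing an annotation, could individual unwanted points be deleted manually? (labelme doesn't support this either, though.)
  3. A window could be added that shows the names of all loaded images and whether each one has been annotated.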
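
For item 1, SAM's `SamPredictor.predict` takes `point_coords` together with `point_labels`, where label 1 marks a foreground point and 0 a background point. A minimal sketch of collecting clicks into that form follows. The helper `build_point_prompt` and the `(x, y, is_positive)` click tuples are assumptions for illustration, and the predictor call is left as a comment so the block runs without SAM installed.

```python
POSITIVE, NEGATIVE = 1, 0  # SAM convention: 1 = foreground point, 0 = background point


def build_point_prompt(clicks):
    """Split (x, y, is_positive) clicks into coordinate and label lists."""
    coords = [[float(x), float(y)] for x, y, _ in clicks]
    labels = [POSITIVE if is_positive else NEGATIVE for _, _, is_positive in clicks]
    return coords, labels


coords, labels = build_point_prompt([(120, 80, True), (200, 150, False)])
print(coords, labels)
# With segment_anything and numpy installed, these would be passed roughly as:
# masks, scores, logits = predictor.predict(
#     point_coords=np.array(coords),
#     point_labels=np.array(labels),
#     multimask_output=True,
# )
```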

Installation guide is incomplete if you are a rookie with GitHub

Maybe add something like this:

To install the segment-anything-annotator repository on an Ubuntu 22.04 PC with Python and PyTorch already installed, follow these steps:

  1. Open a terminal on your Ubuntu PC.
  2. Clone the repository, which creates a local copy on your machine:
    git clone https://github.com/haochenheheda/segment-anything-annotator.git
  3. Navigate to the cloned repository directory:
    cd segment-anything-annotator
  4. Install all the packages listed in the requirements.txt file:
    pip install -r requirements.txt

Something like this in the README would make the code a bit more accessible.

Question

Can it load data that has already been annotated with labelme?
