Comments (2)

Sj-Si commented on August 11, 2024

At the bottom of this comment is a diff with the changes needed to add this functionality.

Instructions

For these instructions, you'll need a Linux terminal, or Git Bash if you're on Windows. Change the file paths to match your environment and OS.

You should have already downloaded the Z3D-E621-Convnext.onnx and tags-selected.csv files from the Discord you linked.

Honestly, I don't know if the .csv file is even required, but I chucked it in there anyway.

  1. Copy the diff from the bottom of this comment to a file my_diff.patch.

    • Name it whatever you want, it doesn't matter.
  2. Navigate to the extension's directory.

cd ~/stable-diffusion-webui/extensions/stable-diffusion-webui-dataset-tag-editor
  3. Apply the patch.
git apply ~/my_diff.patch
  4. Put the two files you downloaded from the e621 Discord in your SD-Webui repo's models directory: stable-diffusion-webui/models/TaggerOnnx/Z3D-E621-Convnext/

    • If I remember correctly, the TaggerOnnx directory didn't exist for me, so I had to create it. The files must also be in their own directory (Z3D-E621-Convnext); otherwise it won't work.
    • The directory structure is set in scripts/dataset_tag_editor/tagger.py by the DEFAULT_ONNX_PATH variable and by the E621 class's self.repo_name variable. Change these if you want to put the files somewhere else.
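
If it helps, here is a minimal Python sketch of the layout described above. It builds the same path the patch builds (models directory, then TaggerOnnx, then Z3D-E621-Convnext), creates it if needed, and reports which of the two downloaded files are still missing. The helper names ensure_model_dir and missing_files are made up for this example, and the assumption that the tagger looks for exactly these two file names comes from the instructions above, not from the loader code.

```python
from pathlib import Path

REPO_NAME = "Z3D-E621-Convnext"  # matches E621.repo_name in the patch
# File names taken from the instructions; assumed to be what the loader expects.
EXPECTED_FILES = ("Z3D-E621-Convnext.onnx", "tags-selected.csv")


def ensure_model_dir(models_path) -> Path:
    """Create <models_path>/TaggerOnnx/<REPO_NAME> if it is missing and return it."""
    target = Path(models_path, "TaggerOnnx", REPO_NAME)
    target.mkdir(parents=True, exist_ok=True)
    return target


def missing_files(models_path) -> list:
    """List the expected files that are not yet present in the model directory."""
    target = ensure_model_dir(models_path)
    return [name for name in EXPECTED_FILES if not (target / name).is_file()]


if __name__ == "__main__":
    # Adjust this path to wherever your stable-diffusion-webui checkout lives.
    models = Path.home() / "stable-diffusion-webui" / "models"
    print(missing_files(models))
```

An empty list means both files are where the patched tagger.py will look for them.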

Important Notes

  • The code is an amalgam of this repo's code along with code from the Tagger extension so don't distribute it or whatever.
  • I feel like I was getting slightly different results from the actual Tagger extension's implementation of the E621 model, so I'm not sure what's going on there. Use at your own risk. I suppose just don't run your billion-dollar company's website using this E621 tagger code.
  • Don't expect any updates on this code. If this repo updates and breaks this code, then you're out of luck. Use this as a jumping-off point and modify it further if needed.
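
For anyone comparing results against the Tagger extension, one place to look is how the patch's E621.predict picks tags. The sketch below restates that selection logic on its own: a threshold of None keeps everything, a negative threshold falls back to the default of 0.35, and the first four labels are always skipped (the patch slices labels[4:], as the WaifuDiffusion code does). Whether the E621 tag list really starts with four entries that should be skipped is not something I have confirmed. The function name filter_labels and the sample tags are made up, and the get_replaced_tag step is left out.

```python
from typing import Dict, List, Optional, Tuple

DEFAULT_THRESHOLD = 0.35  # same default as E621.__init__ in the patch


def filter_labels(
    labels: List[Tuple[str, float]],
    threshold: Optional[float] = None,
    skip: int = 4,  # the patch always slices labels[4:]
) -> Dict[str, float]:
    """Restate the tag selection done by E621.predict in the patch."""
    candidates = labels[skip:]
    if threshold is None:
        # No threshold given: keep every remaining label.
        return dict(candidates)
    if threshold < 0:
        # A negative value means "use the default".
        threshold = DEFAULT_THRESHOLD
    # Strictly greater than, as in the patch.
    return {tag: score for tag, score in candidates if score > threshold}


if __name__ == "__main__":
    # Hypothetical model output, not real predictions.
    sample = [
        ("a", 0.9), ("b", 0.1), ("c", 0.1), ("d", 0.1),
        ("canine", 0.8), ("solo", 0.35), ("outside", 0.2),
    ]
    print(filter_labels(sample, threshold=-1))
    print(filter_labels(sample, threshold=-1, skip=0))
```

Calling it with skip=0 shows what changes if the first four labels are real tags rather than ratings.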
diff --git a/scripts/dataset_tag_editor/dte_logic.py b/scripts/dataset_tag_editor/dte_logic.py
index b2d862c..597e349 100644
--- a/scripts/dataset_tag_editor/dte_logic.py
+++ b/scripts/dataset_tag_editor/dte_logic.py
@@ -15,14 +15,14 @@ from scripts.tokenizer import clip_tokenizer
 WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"]
 WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay  v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
 
-INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)]
+INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru(), tagger.E621()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)]
 INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS]
 
 re_tags = re.compile(r'^([\s\S]+?)( \[\d+\])?$')
 re_newlines = re.compile(r'[\r\n]+')
 
 
-def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd):
+def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_e621, threshold_wd):
     try:
         img = Image.open(path).convert('RGB')
     except:
@@ -33,6 +33,9 @@ def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshol
                 if isinstance(it, tagger.DeepDanbooru):
                     with it as tg:
                         res = tg.predict(img, threshold_booru)
+                elif isinstance(it, tagger.E621):
+                    with it as tg:
+                        res = tg.predict(img, threshold_e621)
                 elif isinstance(it, tagger.WaifuDiffusion):
                     with it as tg:
                         res = tg.predict(img, threshold_wd)
@@ -482,7 +485,22 @@ class DatasetTagEditor(Singleton):
                 print(e)
 
 
-    def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, replace_new_line:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str], max_res:float):
+    def load_dataset(
+        self,
+        img_dir:str,
+        caption_ext:str,
+        recursive:bool,
+        load_caption_from_filename:bool,
+        replace_new_line:bool,
+        interrogate_method:InterrogateMethod,
+        interrogator_names:List[str],
+        threshold_booru:float,
+        threshold_e621: float,
+        threshold_waifu:float,
+        use_temp_dir:bool,
+        kohya_json_path:Optional[str],
+        max_res:float,
+    ):
         self.clear()
 
         img_dir_obj = Path(img_dir)
@@ -561,6 +579,8 @@ class DatasetTagEditor(Singleton):
                         if isinstance(it, tagger.Tagger):
                             if isinstance(it, tagger.DeepDanbooru):
                                 taggers.append((it, threshold_booru))
+                            if isinstance(it, tagger.E621):
+                                taggers.append((it, threshold_e621))
                             if isinstance(it, tagger.WaifuDiffusion):
                                 taggers.append((it, threshold_waifu))
                         elif isinstance(it, captioning.Captioning):
diff --git a/scripts/dataset_tag_editor/interrogators/__init__.py b/scripts/dataset_tag_editor/interrogators/__init__.py
index 726c896..2c98c03 100644
--- a/scripts/dataset_tag_editor/interrogators/__init__.py
+++ b/scripts/dataset_tag_editor/interrogators/__init__.py
@@ -1,6 +1,7 @@
 from .git_large_captioning import GITLargeCaptioning
 from .waifu_diffusion_tagger import WaifuDiffusionTagger
+from .e621_tagger import E621Tagger
 
 __all__ = [
-    'GITLargeCaptioning', 'WaifuDiffusionTagger' 
+    'GITLargeCaptioning', "E621Tagger", 'WaifuDiffusionTagger' 
 ]
\ No newline at end of file
diff --git a/scripts/dataset_tag_editor/tagger.py b/scripts/dataset_tag_editor/tagger.py
index 5ee520b..c4c5b28 100644
--- a/scripts/dataset_tag_editor/tagger.py
+++ b/scripts/dataset_tag_editor/tagger.py
@@ -5,10 +5,14 @@ import numpy as np
 from typing import Optional, Dict
 from modules import devices, shared
 from modules import deepbooru as db
+from modules import shared
+from modules.shared import models_path
+from pathlib import Path
+import os
 
 from .interrogator import Interrogator
 from .interrogators import WaifuDiffusionTagger
-
+from .interrogators import E621Tagger
 
 class Tagger(Interrogator):
     def start(self):
@@ -23,7 +27,7 @@ class Tagger(Interrogator):
 
 def get_replaced_tag(tag: str):
     use_spaces = shared.opts.deepbooru_use_spaces
-    use_escape = shared.opts.deepbooru_escape   
+    use_escape = shared.opts.deepbooru_escape
     if use_spaces:
         tag = tag.replace('_', ' ')
     if use_escape:
@@ -102,5 +106,41 @@ class WaifuDiffusion(Tagger):
 
         return probability_dict
 
+    def name(self):
+        return self.repo_name
+
+DEFAULT_ONNX_PATH = Path(models_path, "TaggerOnnx")
+
+class E621(Tagger):
+    def __init__(self):
+        self.repo_name = "Z3D-E621-Convnext"
+        self.onnx_path = os.path.join(DEFAULT_ONNX_PATH, self.repo_name)
+        self.tagger_inst = E621Tagger(self.onnx_path)
+        self.threshold = 0.35
+
+    def start(self):
+        self.tagger_inst.load()
+        return self
+
+    def stop(self):
+        self.tagger_inst.unload()
+
+    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
+    # set threshold<0 to use default value for now...
+    def predict(self, image: Image.Image, threshold: Optional[float] = None):        
+        # may not use ratings
+        # rating = dict(labels[:4])
+        
+        labels = self.tagger_inst.apply(image)
+        
+        if threshold is not None:
+            if threshold < 0:
+                threshold = self.threshold
+            probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:] if x[1] > threshold])
+        else:
+            probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:]])
+
+        return probability_dict
+
     def name(self):
         return self.repo_name
\ No newline at end of file
diff --git a/scripts/main.py b/scripts/main.py
index 33c613b..27c213f 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -29,7 +29,9 @@ GeneralConfig = namedtuple('GeneralConfig', [
     'use_interrogator', 
     'use_interrogator_names',
     'use_custom_threshold_booru', 
-    'custom_threshold_booru', 
+    'custom_threshold_booru',
+    'use_custom_threshold_e621',
+    'custom_threshold_e621',
     'use_custom_threshold_waifu', 
     'custom_threshold_waifu',
     'save_kohya_metadata',
@@ -44,7 +46,7 @@ BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend'
 EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
 MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
 
-CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, False, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
+CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, False, 'No', [], False, 0.7, False, 0.35, False, 0.35, False, '', '', True, False, False)
 CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
 CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
 CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value, 75)
@@ -116,6 +118,7 @@ def read_general_config():
         ('use_blip_to_prefill', 'BLIP'),
         ('use_git_to_prefill', 'GIT-large-COCO'),
         ('use_booru_to_prefill', 'DeepDanbooru'),
+        ('use_e621_to_prefill', 'E621'),
         ('use_waifu_to_prefill', 'wd-v1-4-vit-tagger')
     ]
     use_interrogator_names = []
@@ -240,10 +243,26 @@ def on_ui_tabs():
         # General
 
         components_general = [
-            ui.toprow.cb_backup, ui.load_dataset.tb_img_directory, ui.load_dataset.tb_caption_file_ext, ui.load_dataset.cb_load_recursive,
-            ui.load_dataset.cb_load_caption_from_filename, ui.load_dataset.cb_replace_new_line_with_comma, ui.load_dataset.rb_use_interrogator, ui.load_dataset.dd_intterogator_names,
-            ui.load_dataset.cb_use_custom_threshold_booru, ui.load_dataset.sl_custom_threshold_booru, ui.load_dataset.cb_use_custom_threshold_waifu, ui.load_dataset.sl_custom_threshold_waifu,
-            ui.toprow.cb_save_kohya_metadata, ui.toprow.tb_metadata_output, ui.toprow.tb_metadata_input, ui.toprow.cb_metadata_overwrite, ui.toprow.cb_metadata_as_caption, ui.toprow.cb_metadata_use_fullpath
+            ui.toprow.cb_backup,
+            ui.load_dataset.tb_img_directory,
+            ui.load_dataset.tb_caption_file_ext,
+            ui.load_dataset.cb_load_recursive,
+            ui.load_dataset.cb_load_caption_from_filename,
+            ui.load_dataset.cb_replace_new_line_with_comma,
+            ui.load_dataset.rb_use_interrogator,
+            ui.load_dataset.dd_intterogator_names,
+            ui.load_dataset.cb_use_custom_threshold_booru,
+            ui.load_dataset.sl_custom_threshold_booru,
+            ui.load_dataset.cb_use_custom_threshold_e621,
+            ui.load_dataset.sl_custom_threshold_e621,
+            ui.load_dataset.cb_use_custom_threshold_waifu,
+            ui.load_dataset.sl_custom_threshold_waifu,
+            ui.toprow.cb_save_kohya_metadata,
+            ui.toprow.tb_metadata_output,
+            ui.toprow.tb_metadata_input,
+            ui.toprow.cb_metadata_overwrite,
+            ui.toprow.cb_metadata_as_caption,
+            ui.toprow.cb_metadata_use_fullpath
         ]
         components_filter = \
             [ui.filter_by_tags.tag_filter_ui.cb_prefix, ui.filter_by_tags.tag_filter_ui.cb_suffix, ui.filter_by_tags.tag_filter_ui.cb_regex, ui.filter_by_tags.tag_filter_ui.rb_sort_by, ui.filter_by_tags.tag_filter_ui.rb_sort_order, ui.filter_by_tags.tag_filter_ui.rb_logic] +\
diff --git a/scripts/tag_editor_ui/block_load_dataset.py b/scripts/tag_editor_ui/block_load_dataset.py
index 9437a4d..292d7e6 100644
--- a/scripts/tag_editor_ui/block_load_dataset.py
+++ b/scripts/tag_editor_ui/block_load_dataset.py
@@ -43,6 +43,9 @@ class LoadDatasetUI(UIBase):
                 with gr.Row():
                     self.cb_use_custom_threshold_booru = gr.Checkbox(value=cfg_general.use_custom_threshold_booru, label='Use Custom Threshold (Booru)', interactive=True)
                     self.sl_custom_threshold_booru = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_booru, step=0.01, interactive=True, label='Booru Score Threshold')
+                with gr.Row():
+                    self.cb_use_custom_threshold_e621 = gr.Checkbox(value=cfg_general.use_custom_threshold_e621, label='Use Custom Threshold (E621)', interactive=True)
+                    self.sl_custom_threshold_e621 = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_e621, step=0.01, interactive=True, label='E621 Score Threshold')
                 with gr.Row():
                     self.cb_use_custom_threshold_waifu = gr.Checkbox(value=cfg_general.use_custom_threshold_waifu, label='Use Custom Threshold (WDv1.4 Tagger)', interactive=True)
                     self.sl_custom_threshold_waifu = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_waifu, step=0.01, interactive=True, label='WDv1.4 Tagger Score Threshold')
@@ -58,6 +61,8 @@ class LoadDatasetUI(UIBase):
             use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0
             use_custom_threshold_booru: bool,
             custom_threshold_booru: float,
+            use_custom_threshold_e621: bool,
+            custom_threshold_e621: float,
             use_custom_threshold_waifu: bool,
             custom_threshold_waifu: float,
             use_kohya_metadata: bool,
@@ -75,9 +80,10 @@ class LoadDatasetUI(UIBase):
                 interrogate_method = InterrogateMethod.APPEND
 
             threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
+            threshold_e621 = custom_threshold_e621 if use_custom_threshold_e621 else -1
             threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1
 
-            dte_instance.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, replace_new_line, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path if use_kohya_metadata else None, opts.dataset_editor_max_res)
+            dte_instance.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, replace_new_line, interrogate_method, use_interrogator_names, threshold_booru, threshold_e621, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path if use_kohya_metadata else None, opts.dataset_editor_max_res)
             imgs = dte_instance.get_filtered_imgs(filters=[])
             img_indices = dte_instance.get_filtered_imgindices(filters=[])
             return [
@@ -90,7 +96,23 @@ class LoadDatasetUI(UIBase):
         
         self.btn_load_datasets.click(
             fn=load_files_from_dir,
-            inputs=[self.tb_img_directory, self.tb_caption_file_ext, self.cb_load_recursive, self.cb_load_caption_from_filename, self.cb_replace_new_line_with_comma, self.rb_use_interrogator, self.dd_intterogator_names, self.cb_use_custom_threshold_booru, self.sl_custom_threshold_booru, self.cb_use_custom_threshold_waifu, self.sl_custom_threshold_waifu, toprow.cb_save_kohya_metadata, toprow.tb_metadata_output],
+            inputs=[
+                self.tb_img_directory,
+                self.tb_caption_file_ext,
+                self.cb_load_recursive,
+                self.cb_load_caption_from_filename,
+                self.cb_replace_new_line_with_comma,
+                self.rb_use_interrogator,
+                self.dd_intterogator_names,
+                self.cb_use_custom_threshold_booru,
+                self.sl_custom_threshold_booru,
+                self.cb_use_custom_threshold_e621,
+                self.sl_custom_threshold_e621,
+                self.cb_use_custom_threshold_waifu,
+                self.sl_custom_threshold_waifu,
+                toprow.cb_save_kohya_metadata,
+                toprow.tb_metadata_output,
+            ],
             outputs=
             [dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] +
             [dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] +
diff --git a/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py b/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py
index 47d493d..e57d2cb 100644
--- a/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py
+++ b/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py
@@ -133,16 +133,33 @@ class EditCaptionOfSelectedImageUI(UIBase):
             outputs=[self.tb_edit_caption]
         )
 
-        def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float):
+        def interrogate_selected_image(
+            interrogator_name: str,
+            use_threshold_booru: bool,
+            threshold_booru: float,
+            use_threshold_e621: bool,
+            threshold_e621: float,
+            use_threshold_waifu: bool,
+            threshold_waifu: float,
+        ):
             if not interrogator_name:
                 return ''
             threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
+            threshold_e621 = threshold_e621 if use_threshold_e621 else -1
             threshold_waifu = threshold_waifu if use_threshold_waifu else -1
-            return dte_module.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
+            return dte_module.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_e621, threshold_waifu)
 
         self.btn_interrogate_si.click(
             fn=interrogate_selected_image,
-            inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu],
+            inputs=[
+                self.dd_intterogator_names_si,
+                load_dataset.cb_use_custom_threshold_booru,
+                load_dataset.sl_custom_threshold_booru,
+                load_dataset.cb_use_custom_threshold_e621,
+                load_dataset.sl_custom_threshold_e621,
+                load_dataset.cb_use_custom_threshold_waifu,
+                load_dataset.sl_custom_threshold_waifu,
+            ],
             outputs=[self.tb_interrogate]
         )

from stable-diffusion-webui-dataset-tag-editor.

toshiaki1729 commented on August 11, 2024

implemented in #93
