Comments (6)

ExponentialML commented on July 23, 2024 5

Okay, so I came up with a solution to train multiple subjects into one model. The idea is to assign a placeholder token to each image, so that each token is only trained against its own image set. With this method you no longer need to change the hard-coded token, because the token comes from the filename.

For ease of use, you prepend a placeholder to the filename, and it gets parsed during training. Ideally you would create a dict or a custom dataloader, but this solution works.

For example: dog_1.png, dog_2.png, cat_1.png, with one prefix for each different thing you want to train, all in your training folder.
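The dict mentioned above could look something like the following. This is only a minimal sketch and not part of the posted script: `group_by_token` is a hypothetical helper, and it assumes every filename starts with the token followed by an underscore.

```python
from collections import defaultdict
from pathlib import Path


def group_by_token(filenames):
    """Map each filename prefix (the text before the first "_") to its files."""
    groups = defaultdict(list)
    for name in filenames:
        # Same parsing rule as the script: take everything before the first "_".
        token = Path(name).name.split("_")[0]
        groups[token].append(name)
    return dict(groups)


groups = group_by_token(["dog_1.png", "dog_2.png", "cat_1.png"])
print(groups)
```

A dict like this would make it easy to check how many images each token has before starting a run.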

This all happens in the personalized.py script. The token before the _ in the filename gets parsed, and each image is then trained with its parsed filename token against the class that you set in the training parameter (e.g. "toy").
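As a rough illustration of that parsing step, here is a small standalone sketch. `build_caption` is a hypothetical name used only for this example; the template string and the split on "_" mirror what the script does in `__getitem__`.

```python
from pathlib import Path


def build_caption(image_path, class_word, template="a {} {}"):
    """Build a training caption from the filename prefix and a class word."""
    # The token is whatever precedes the first underscore in the filename.
    token = Path(image_path).name.split("_")[0]
    return template.format(token, class_word)


caption = build_caption("training/dog_1.png", "toy")
print(caption)  # a dog toy
```

Note that a filename without an underscore is not split at all, so the whole filename (extension included) would end up in the caption.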

This is set up for just the placeholder token, but the same method can be applied to classes, so you can have multiple classes ("animal_", "car_", "person_") and tokens in one go. I haven't tested this yet, but there's no reason why it shouldn't work if you wish to implement it.
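One way the class idea might be sketched is below. This is untested against the actual training code and rests on an assumption: the naming scheme `<class>_<token>_<index>.png` is made up for this example, and `parse_class_and_token` is a hypothetical helper.

```python
from pathlib import Path


def parse_class_and_token(image_path):
    """Split an assumed "<class>_<token>_<index>" filename into (class, token)."""
    parts = Path(image_path).stem.split("_")
    if len(parts) < 3:
        raise ValueError(f"expected <class>_<token>_<index>, got {image_path!r}")
    return parts[0], parts[1]


class_word, token = parse_class_and_token("animal_buddy_1.png")
print(f"a {token} {class_word}")  # a buddy animal
```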

You also need to increase the epochs depending on how many images you have in your dataset. In my tests, I used just three images per token. This also increases training time, so if a good finetune takes 15 minutes with 5 images, a basic estimate is to multiply that by the number of image sets you have.
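The rule of thumb above is just a linear scaling, sketched here with a hypothetical helper. It is a rough estimate, not a measured result.

```python
def estimate_minutes(minutes_per_image_set, num_image_sets):
    """Scale the single-subject training time by the number of image sets."""
    return minutes_per_image_set * num_image_sets


# Three subjects at roughly 15 minutes each.
total = estimate_minutes(15, 3)
print(total)  # 45
```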

Here's the script that gets this working. All you need to do is replace your personalized.py with this one, follow the instructions above, and it should work. I've also changed the templates to a {} instead of photo of a {} as it seems to give me good results, but feel free to change it back.

import os
import numpy as np
import PIL
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import random

training_templates_smallest = [
    'a {} {}',
]

reg_templates_smallest = [
    'a {}',
]

class PersonalizedBase(Dataset):
    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="dog",
                 per_image_tokens=False,
                 center_crop=False,
                 mixing_prob=0.25,
                 coarse_class_text=None,
                 reg = False
                 ):

        self.data_root = data_root

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images 

        self.placeholder_token = placeholder_token

        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob

        self.coarse_class_text = coarse_class_text

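        # NOTE: per_img_token_list is not defined in this version of the script,
        # so leave per_image_tokens=False unless you define that list yourself.
        if per_image_tokens:
            assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."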

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        # PIL.Image.LINEAR was an alias of BILINEAR and was removed in Pillow 10.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.reg = reg

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")
            
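        # Second template slot: the value passed in as placeholder_token
        # (optionally prefixed with coarse_class_text below).
        placeholder_string = self.placeholder_token
        # First template slot: the filename prefix before the first "_".
        pathname = Path(self.image_paths[i % self.num_images]).name
        placeholder_token = pathname.split("_")[0]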
        
        if self.coarse_class_text:
            placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
        
        if not self.reg:
            #You only need to use 1 template for Dreambooth, but you can try more if you wished (not recommended)
            text = random.choice(training_templates_smallest).format(placeholder_token, placeholder_string)
        else:
            text = random.choice(reg_templates_smallest).format(placeholder_string)
            
        example["caption"] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        
        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2:(h + crop) // 2,
                (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example

from dreambooth-stable-diffusion.

eadnams22 commented on July 23, 2024 2

@ExponentialML How do you handle multiple classes for regularization with these diverse classes as well? Say I have pools of regularization images for "Woman", "Man", "Cat" and "Dog", and new training images for "Jane", "John", "Buddy" and "Ollie" to train into each of those classes.

How do I match them to their respective classes in the same model training session?

ExponentialML commented on July 23, 2024 1

When training the second person did you change the name of the hardcoded token "sks"?

https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/ldm/data/personalized.py#L10

This happens even when changing the hardcoded token. There's a problem somewhere that's pushing all embeddings to the same space during training. I thought it was the global seed, but that isn't the case.

binaryninja commented on July 23, 2024

When training the second person did you change the name of the hardcoded token "sks"?

https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/ldm/data/personalized.py#L10

JoePenna commented on July 23, 2024

@binaryninja @ExponentialML absolutely, I changed every single thing I could. Including the training prompts.

ExponentialML commented on July 23, 2024

@binaryninja @ExponentialML absolutely, I changed every single thing I could. Including the training prompts.

I haven't tried this as I'm a bit busy this week, but this could be a possible solution.

After training, prune the model checkpoints from 11GB to 2GB using this script. This can be done before or after training presumably:
https://github.com/harubaru/waifu-diffusion/blob/main/scripts/prune.py

Then, merge all of the model checkpoints as stated in this repository:
https://github.com/Jack000/glid-3-xl-stable#trainingfine-tuning

I don't know what effect it would have, but this might be a step in the right direction if there aren't any viable fixes to the problem.

EDIT:

Tried it, and it doesn't work. I also tried concatenating the tensors, but no dice.
