Comments (6)
Okay, so I came up with a solution for training multiple subjects into one model. The idea is to assign a placeholder token to each image, so that each token is trained only against its own image set. With this method you no longer need to change the hard-coded token, since the token comes from the filename.
For ease of use, you prepend a placeholder to the filename, then parse it during training. Ideally you would want to create a dict or a custom dataloader, but this solution works.
An example is: dog_1.png, dog_2.png, cat_1.png for each different instance of the thing you want to train, all in your training folder.
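The dict approach mentioned above could look something like the sketch below. The helper name group_by_token is hypothetical (it is not part of personalized.py), and it assumes every filename follows the token_number.ext pattern.

```python
from collections import defaultdict
from pathlib import Path


def group_by_token(filenames):
    """Group image filenames by the token before the first underscore.

    Hypothetical helper for illustration; not part of personalized.py.
    """
    groups = defaultdict(list)
    for name in filenames:
        # "dog_1.png" -> "dog"
        token = Path(name).name.split("_")[0]
        groups[token].append(name)
    return dict(groups)


files = ["dog_1.png", "dog_2.png", "cat_1.png"]
print(group_by_token(files))
```

Grouping up front like this would also make it easy to check how many images each token has before training starts.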
This all works in the personalized.py script. The token before the _ in the filename gets parsed, then each image and its parsed filename token are trained against the class that you have set in the training parameter (e.g. "toy").
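As a rough illustration of that parsing step, here is a minimal standalone sketch. The function name caption_for and the training/ folder are made up for the example; the split on "_" and the 'a {} {}' template are the ones used in the script.

```python
from pathlib import Path


def caption_for(image_path, class_word, template="a {} {}"):
    """Build a training caption from the filename token and the class word.

    Hypothetical helper; it mirrors the split("_")[0] logic in the script.
    """
    token = Path(image_path).name.split("_")[0]
    return template.format(token, class_word)


print(caption_for("training/dog_1.png", "toy"))  # a dog toy
print(caption_for("training/cat_1.png", "toy"))  # a cat toy
```

Note that a filename with no underscore would yield the whole filename (extension included) as the token, so the naming pattern matters.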
This is set up for just the placeholder token, but the same method can be applied to classes, so you can have multiple classes ("animal_", "car_", "person_") and tokens in one go. I haven't tested this yet, but there's no reason why it shouldn't work if you wish to implement it.
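One way the class idea might be sketched is shown below. This is untested in actual training, and the class_token_index.ext naming scheme (for example animal_dog_1.png) and the parse_class_and_token helper are assumptions made for illustration, not something the script supports as posted.

```python
from pathlib import Path


def parse_class_and_token(image_path):
    """Split 'class_token_index.ext' into (class_word, token).

    Assumed naming scheme for illustration only; not part of personalized.py.
    """
    parts = Path(image_path).stem.split("_")
    if len(parts) < 3:
        raise ValueError(f"expected class_token_index, got {image_path!r}")
    class_word, token = parts[0], parts[1]
    return class_word, token


for name in ["animal_dog_1.png", "person_jane_2.png"]:
    class_word, token = parse_class_and_token(name)
    # Same 'a {} {}' template as the script: token first, then class.
    print("a {} {}".format(token, class_word))
```

With something like this, the class word would come from the filename instead of the single class set in the training parameter.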
You also need to increase the epochs depending on how many images you have in your dataset. In my tests, I used just three images per token. This also increases training time: if a good finetune takes 15 minutes with 5 images, a rough estimate is to multiply that by the number of image sets you have.
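That rule of thumb is simple linear scaling, sketched below. The 15-minute figure is the example number from this post; actual time depends on hardware and settings, and estimated_minutes is just an illustrative name.

```python
def estimated_minutes(num_image_sets, minutes_per_set=15):
    """Linear estimate: one set's training time times the number of sets."""
    return num_image_sets * minutes_per_set


# Three subjects (e.g. two dogs and a cat) at roughly 15 minutes each.
print(estimated_minutes(3))  # 45
```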
Here's the script that gets this working. All you need to do is replace your personalized.py with this one, follow the instructions above, and it should work. I've also changed the templates to a {} instead of photo of a {}, as it seems to give me good results, but feel free to change it back.
import os
import numpy as np
import PIL
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random

# Training caption: "a <filename token> <class>"
training_templates_smallest = [
    'a {} {}',
]

# Regularization caption: "a <class>"
reg_templates_smallest = [
    'a {}',
]


class PersonalizedBase(Dataset):
    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="dog",
                 per_image_tokens=False,
                 center_crop=False,
                 mixing_prob=0.25,
                 coarse_class_text=None,
                 reg=False
                 ):
        self.data_root = data_root
        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token
        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob
        self.coarse_class_text = coarse_class_text

        if per_image_tokens:
            # NOTE: per_img_token_list is not defined in this snippet, so it
            # must be defined before per_image_tokens is enabled.
            assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        # Image.LINEAR was an alias of Image.BILINEAR and was removed in Pillow 10.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.reg = reg

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if image.mode != "RGB":
            image = image.convert("RGB")

        # placeholder_string holds the value of the placeholder_token parameter
        # (the class word); placeholder_token is parsed from the filename,
        # e.g. "dog_1.png" -> "dog".
        placeholder_string = self.placeholder_token
        pathname = Path(self.image_paths[i % self.num_images]).name
        placeholder_token = pathname.split("_")[0]

        if self.coarse_class_text:
            placeholder_string = f"{self.coarse_class_text} {placeholder_string}"

        if not self.reg:
            # You only need to use 1 template for Dreambooth, but you can try more if you wish (not recommended)
            text = random.choice(training_templates_smallest).format(placeholder_token, placeholder_string)
        else:
            text = random.choice(reg_templates_smallest).format(placeholder_string)

        example["caption"] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Scale pixel values from [0, 255] to [-1, 1]
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
from dreambooth-stable-diffusion.
@ExponentialML How do you handle multiple classes so that each of these diverse classes gets regularized as well? Say I have "Woman", "Man", "Cat", and "Dog" pools of regularization images, and new training images for "Jane", "John", "Buddy", and "Ollie" to train into each of those classes.
How do I match them to their respective classes in the same model training session?
from dreambooth-stable-diffusion.
When training the second person did you change the name of the hardcoded token "sks"?
https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/ldm/data/personalized.py#L10
This happens even when changing the hardcoded token. There's a problem somewhere that's pushing all embeddings to the same space during training. I thought it was the global seed, but that isn't the case.
from dreambooth-stable-diffusion.
When training the second person did you change the name of the hardcoded token "sks"?
https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/ldm/data/personalized.py#L10
from dreambooth-stable-diffusion.
@binaryninja @ExponentialML absolutely, I changed every single thing I could. Including the training prompts.
from dreambooth-stable-diffusion.
@binaryninja @ExponentialML absolutely, I changed every single thing I could. Including the training prompts.
I haven't tried this as I'm a bit busy this week, but this could be a possible solution.
After training, prune the model checkpoints from 11GB to 2GB using this script. This can be done before or after training presumably:
https://github.com/harubaru/waifu-diffusion/blob/main/scripts/prune.py
Then, merge all of the model checkpoints as stated in this repository:
https://github.com/Jack000/glid-3-xl-stable#trainingfine-tuning
I don't know what effect it would have, but this might be a step in the right direction if there aren't any viable fixes to the problem.
EDIT:
Tried it and it doesn't work. I've also tried concatenating the tensors, but no dice.
from dreambooth-stable-diffusion.