Comments (1)
I have solved this problem. Just replace the NullInversion class with the following code:
class NullInversion:
    """Null-text inversion of a real image for a diffusers-style pipeline.

    The wrapped ``model`` must expose ``unet``, ``vae``, ``tokenizer``,
    ``text_encoder``, ``scheduler`` and ``device``.

    The procedure has two stages:

    1. DDIM inversion of the image under the prompt's conditional embedding.
    2. For every denoising step, optimization of the unconditional
       ("null-text") embedding so that classifier-free-guided DDIM sampling
       retraces the inverted trajectory.

    Both ``"epsilon"`` and ``"v_prediction"`` schedulers are supported.

    Relies on module-level names defined elsewhere in the file:
    ``NUM_DDIM_STEPS``, ``GUIDANCE_SCALE``, ``tqdm``, ``Adam``, ``nnf`` and
    ``load_512``.
    """

    def __init__(self, model):
        # Uses the pipeline's own scheduler. An earlier variant constructed
        # DDIMScheduler(beta_start=0.00085, beta_end=0.012,
        # beta_schedule="scaled_linear", clip_sample=False,
        # set_alpha_to_one=False) here instead.
        self.model = model
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
        self.prompt = None
        self.context = None

    @property
    def scheduler(self):
        """The scheduler of the wrapped pipeline."""
        return self.model.scheduler

    def _step_ratio(self):
        """Return the number of training timesteps between two consecutive inference steps."""
        return self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps

    def _latent_scale(self):
        """Return the VAE latent scaling factor.

        Uses ``vae.config.scaling_factor`` when available and falls back to
        0.18215 otherwise.
        """
        vae_config = getattr(self.model.vae, "config", None)
        return getattr(vae_config, "scaling_factor", 0.18215)

    def _split_model_output(self, model_output, sample, alpha_prod_t):
        """Convert a UNet output into ``(pred_original_sample, pred_epsilon)``.

        The interpretation follows ``scheduler.config.prediction_type``:
        ``"epsilon"`` means ``model_output`` is the predicted noise, and
        ``"v_prediction"`` means it is the predicted velocity.

        Raises:
            ValueError: if the prediction type is anything else.
        """
        beta_prod_t = 1 - alpha_prod_t
        pred_type = self.scheduler.config.prediction_type
        if pred_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
            pred_epsilon = model_output
        elif pred_type == "v_prediction":
            pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
            pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
        else:
            raise ValueError(f"Unknown prediction type {pred_type}")
        return pred_original_sample, pred_epsilon

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        """Take one deterministic DDIM denoising step from ``timestep`` to ``timestep - step``.

        When the previous timestep would be negative (the last step), the
        scheduler's ``final_alpha_cumprod`` is used for it.
        """
        prev_timestep = timestep - self._step_ratio()
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.scheduler.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.scheduler.final_alpha_cumprod
        )
        pred_original_sample, pred_epsilon = self._split_model_output(model_output, sample, alpha_prod_t)
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * pred_epsilon
        return alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        """Take one DDIM inversion step (the reverse of ``prev_step``).

        ``timestep`` is the target, noisier timestep. The current timestep is
        ``timestep - step``, clamped to the last training timestep. When the
        current timestep is negative, ``final_alpha_cumprod`` is used for it.
        """
        last_train_timestep = self.scheduler.config.num_train_timesteps - 1
        cur_timestep = min(timestep - self._step_ratio(), last_train_timestep)
        next_timestep = timestep
        alpha_prod_t = (
            self.scheduler.alphas_cumprod[cur_timestep]
            if cur_timestep >= 0
            else self.scheduler.final_alpha_cumprod
        )
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        next_original_sample, pred_epsilon = self._split_model_output(model_output, sample, alpha_prod_t)
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * pred_epsilon
        return alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction

    def get_noise_pred_single(self, latents, t, context):
        """Run the UNet once on ``latents`` at timestep ``t`` with text embeddings ``context``."""
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        """Run one classifier-free-guided step and return the updated latents.

        ``context`` is the concatenation ``[uncond, cond]`` and defaults to
        ``self.context``. When ``is_forward`` is true, the guidance scale is 1
        and an inversion step (``next_step``) is taken. Otherwise the scale is
        ``GUIDANCE_SCALE`` and a denoising step (``prev_step``) is taken.
        """
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else GUIDANCE_SCALE
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        """Decode latents with the VAE.

        With ``return_type='np'``, returns the first image of the batch as an
        HWC ``uint8`` NumPy array. Otherwise returns the raw decoder output.
        """
        latents = 1 / self._latent_scale() * latents.detach()
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        """Encode an HWC image into scaled VAE latents (the VAE posterior mean).

        ``image`` may be a NumPy array, a ``torch.Tensor``, or anything
        ``np.array`` can convert (e.g. a PIL image). Pixel values are mapped
        with ``x / 127.5 - 1``, so they presumably lie in [0, 255].
        A 4-D tensor is treated as already-computed latents and is returned
        unchanged.
        """
        if isinstance(image, torch.Tensor) and image.dim() == 4:
            return image
        if not isinstance(image, (np.ndarray, torch.Tensor)):
            # e.g. a PIL image -> HWC array
            image = np.array(image)
        image = torch.as_tensor(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        latents = self.model.vae.encode(image)['latent_dist'].mean
        return latents * self._latent_scale()

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        """Encode the empty prompt and ``prompt``.

        Stores ``self.context = cat([uncond, cond])`` and ``self.prompt``.
        """
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        """Invert ``latent`` from clean to noisy using the conditional embedding only.

        Returns the list of all intermediate latents, starting with the input.
        """
        _, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        timesteps = self.scheduler.timesteps
        for i in range(NUM_DDIM_STEPS):
            # walk the inference timesteps in reverse (least noisy first)
            t = timesteps[len(timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @torch.no_grad()
    def ddim_inversion(self, image):
        """Return ``(reconstructed_image, ddim_latents)`` for ``image``."""
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        """Optimize one unconditional embedding per denoising step.

        At each step, up to ``num_inner_steps`` Adam updates minimize the MSE
        between the guided ``prev_step`` result and the matching DDIM-inverted
        latent. The inner loop stops early once the loss drops below
        ``epsilon + i * 2e-5``.

        Returns:
            A list of optimized unconditional embeddings, one per step.
        """
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        with tqdm(total=num_inner_steps * NUM_DDIM_STEPS) as bar:
            for i in range(NUM_DDIM_STEPS):
                uncond_embeddings = uncond_embeddings.clone().detach()
                uncond_embeddings.requires_grad = True
                # learning rate decays linearly with the outer step index
                optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
                latent_prev = latents[len(latents) - i - 2]
                t = self.scheduler.timesteps[i]
                with torch.no_grad():
                    noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
                steps_done = 0
                for _ in range(num_inner_steps):
                    noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                    noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
                    latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                    loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    steps_done += 1
                    bar.update()
                    # the early-stop tolerance loosens slightly with each outer step
                    if loss.item() < epsilon + i * 2e-5:
                        break
                # keep the progress bar total consistent when stopping early
                if steps_done < num_inner_steps:
                    bar.update(num_inner_steps - steps_done)
                uncond_embeddings_list.append(uncond_embeddings[:1].detach())
                with torch.no_grad():
                    context = torch.cat([uncond_embeddings, cond_embeddings])
                    latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        return uncond_embeddings_list

    def invert(self, image_path: str, prompt: str, offsets=(0, 0, 0, 0), num_inner_steps=10,
               early_stop_epsilon=1e-5, verbose=False):
        """Run DDIM inversion followed by null-text optimization on an image file.

        Returns:
            ``((image_gt, image_rec), final_inverted_latent, uncond_embeddings)``.
        """
        self.init_prompt(prompt)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)
        if verbose:
            print("Null-text optimization...")
        uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
from prompt-to-prompt.
Related Issues (20)
- Can't load a fine-tuned model due to using an old version of diffusers
- On the non-convergence of null text
- Is training code available? HOT 1
- Does it support SDXL? HOT 2
- When I use Diffusers==0.21.0, p2p seems to generate the same as w/o p2p. Code needs to be upgraded.
bad results when changing the clothes' color HOT 1
- code for user-defined mask HOT 3
- If I want to add null text inversion to the training process to maintain the feature of the edited image, how can I achieve this?
- code about Delta Denoising Score
- Can we give image as input? HOT 1
- wrong DDIM inversion step HOT 1
- The setting of DDIMScheduler.
Why are the images in the article so nice and lifelike, while the ones I generate myself are low quality?
- null text inversion sdxl support require HOT 2
- DDS_zeroshot.ipynb - sds loss derivation HOT 2
How to run? Can someone give me detailed environment requirements, such as the versions of diffusers and transformers? HOT 2
- visualizing self attention map HOT 1
- Error introduced when using p2p pipeline comparing to null-text inversion HOT 1
Can it run in a Windows environment?
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization, using data art.
-
Game
Something interesting about games, to make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
Tencent's open source team in China.
from prompt-to-prompt.