
Comments (5)

Georgefwt commented on September 26, 2024

I have encountered the same problem. Do the authors have any advice on how to handle cases where guidance_scale is greater than 1?


johndpope commented on September 26, 2024

I jammed renoise_inversion.py plus the above error into Claude, and it spat out this answer:

_The error you encountered is likely due to a mismatch in the tensor sizes when using a non-zero guidance scale during the inversion process. In the provided code, the guidance scale is set to 0 during inversion and 1.0 during inference for specific reasons.
Using a guidance scale greater than 1 during inversion can be problematic because it involves splitting the batch size and performing classifier-free guidance. This requires careful handling of tensor sizes throughout the pipeline to ensure consistency.
The reason for using a guidance scale of 0 during inversion is to focus solely on reconstructing the input image without any additional conditioning or guidance. The goal is to find the latent representation that closely matches the input image, and introducing guidance during this process may interfere with the reconstruction quality.
On the other hand, during inference, a guidance scale of 1.0 is used to apply a small amount of conditioning to the generated image. This allows for some control over the generated output while still maintaining a close resemblance to the original image.
To use a higher guidance scale (e.g., 7) during inversion, you would need to modify the code to handle the tensor sizes correctly throughout the pipeline. This would involve ensuring that the tensors are split and concatenated appropriately to match the expected dimensions at each step.
Here are a few suggestions to address the issue:

Modify the unet_pass function to handle the case when pipe.do_classifier_free_guidance is True during inversion. Ensure that the latent model input and other tensors are split and concatenated correctly.
Update the inversion_step function to handle the splitting and concatenation of tensors when using classifier-free guidance during inversion. Make sure that the tensor sizes match at each step, especially when performing noise regularization.
Adjust the tensor sizes in the patchify_latents_kl_divergence and latents_kl_divergence functions to handle the case when the batch size is doubled due to classifier-free guidance.

It's important to note that using a high guidance scale during inversion may impact the quality of the reconstructed image, as it introduces additional conditioning that may deviate from the original input. Experimentation and fine-tuning may be necessary to find the right balance between reconstruction quality and guidance strength.
If you choose to modify the code to support a higher guidance scale during inversion, be prepared to carefully debug and ensure that the tensor sizes and shapes are consistent throughout the pipeline to avoid any runtime errors._
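
For reference, the tensor bookkeeping Claude describes is the standard classifier-free guidance pattern: double the latent batch, run one UNet pass that covers both the unconditional and the text-conditioned branch, then split and recombine. A minimal sketch under stated assumptions (the name cfg_unet_pass is illustrative, not from the repo, and a diffusers-style UNet with a .sample output field is assumed):

import torch

def cfg_unet_pass(unet, z_t, t, uncond_embeds, text_embeds, guidance_scale):
    # Double the latent batch so a single UNet call covers both the
    # unconditional and the text-conditioned branch.
    latent_in = torch.cat([z_t] * 2)
    prompt_embeds = torch.cat([uncond_embeds, text_embeds])
    noise_pred = unet(latent_in, t, encoder_hidden_states=prompt_embeds).sample
    # Split the doubled prediction back apart and combine the branches.
    noise_uncond, noise_text = noise_pred.chunk(2)
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

Every intermediate tensor between the cat and the chunk carries the doubled batch, which is exactly where the size-mismatch errors come from when only part of the pipeline is CFG-aware.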

UPDATE
I then asked it to provide code:

def noise_regularization(
    e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls, generator=None, do_classifier_free_guidance=False
):
    for _outer in range(num_reg_steps):
        if lambda_kl > 0:
            _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
            if do_classifier_free_guidance:
                # note: reassigning _var to the cat result makes it a non-leaf
                # tensor, so _var.grad below comes back None; noise_pred_optimal
                # is also re-doubled on every outer iteration
                _var = torch.cat([_var] * 2)
                noise_pred_optimal = torch.cat([noise_pred_optimal] * 2)
            l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
            l_kld.backward()
            _grad = _var.grad.detach()
            _grad = torch.clip(_grad, -100, 100)
            if do_classifier_free_guidance:
                _grad = _grad.chunk(2)[0]
            e_t = e_t - lambda_kl * _grad
        if lambda_ac > 0:
            for _inner in range(num_ac_rolls):
                _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                if do_classifier_free_guidance:
                    _var = torch.cat([_var] * 2)  # same non-leaf issue as above
                l_ac = auto_corr_loss(_var, generator=generator)
                l_ac.backward()
                _grad = _var.grad.detach() / num_ac_rolls
                if do_classifier_free_guidance:
                    _grad = _grad.chunk(2)[0]
                e_t = e_t - lambda_ac * _grad
        e_t = e_t.detach()

    return e_t


def inversion_step(
    pipe,
    z_t: torch.Tensor,
    t: torch.Tensor,
    prompt_embeds,
    added_cond_kwargs,
    num_renoise_steps: int = 100,
    first_step_max_timestep: int = 250,
    generator=None,
) -> torch.Tensor:
    extra_step_kwargs = {}
    avg_range = pipe.cfg.average_first_step_range if t.item() < first_step_max_timestep else pipe.cfg.average_step_range
    num_renoise_steps = min(pipe.cfg.max_num_renoise_steps_first_step, num_renoise_steps) if t.item() < first_step_max_timestep else num_renoise_steps

    noise_pred_avg = None
    noise_pred_optimal = None
    z_tp1_forward = pipe.scheduler.add_noise(pipe.z_0, pipe.noise, t.view((1))).detach()

    approximated_z_tp1 = z_t.clone()
    for i in range(num_renoise_steps + 1):

        with torch.no_grad():
            # if noise regularization is enabled, we need to double the batch size for the first step
            if pipe.cfg.noise_regularization_num_reg_steps > 0 and i == 0:
                approximated_z_tp1 = torch.cat([z_tp1_forward, approximated_z_tp1])
                prompt_embeds_in = torch.cat([prompt_embeds, prompt_embeds])
                if added_cond_kwargs is not None:
                    added_cond_kwargs_in = {}
                    added_cond_kwargs_in['text_embeds'] = torch.cat([added_cond_kwargs['text_embeds'], added_cond_kwargs['text_embeds']])
                    added_cond_kwargs_in['time_ids'] = torch.cat([added_cond_kwargs['time_ids'], added_cond_kwargs['time_ids']])
                else:
                    added_cond_kwargs_in = None
            else:
                prompt_embeds_in = prompt_embeds
                added_cond_kwargs_in = added_cond_kwargs

            noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds_in, added_cond_kwargs_in)

            # if noise regularization is enabled, we need to split the batch size for the first step
            if pipe.cfg.noise_regularization_num_reg_steps > 0 and i == 0:
                noise_pred_optimal, noise_pred = noise_pred.chunk(2)
                noise_pred_optimal = noise_pred_optimal.detach()

            # perform guidance
            if pipe.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Calculate average noise
            if i >= avg_range[0] and i < avg_range[1]:
                j = i - avg_range[0]
                if noise_pred_avg is None:
                    noise_pred_avg = noise_pred.clone()
                else:
                    noise_pred_avg = j * noise_pred_avg / (j + 1) + noise_pred / (j + 1)

        if i >= avg_range[0] or (not pipe.cfg.average_latent_estimations and i > 0):
            noise_pred = noise_regularization(
                noise_pred, noise_pred_optimal,
                lambda_kl=pipe.cfg.noise_regularization_lambda_kl,
                lambda_ac=pipe.cfg.noise_regularization_lambda_ac,
                num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps,
                num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls,
                generator=generator,
                do_classifier_free_guidance=pipe.do_classifier_free_guidance,
            )
        
        approximated_z_tp1 = pipe.scheduler.inv_step(noise_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()

    # if average latents is enabled, we need to perform an additional step with the average noise
    if pipe.cfg.average_latent_estimations and noise_pred_avg is not None:
        noise_pred_avg = noise_regularization(
            noise_pred_avg, noise_pred_optimal,
            lambda_kl=pipe.cfg.noise_regularization_lambda_kl,
            lambda_ac=pipe.cfg.noise_regularization_lambda_ac,
            num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps,
            num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls,
            generator=generator,
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        )
        approximated_z_tp1 = pipe.scheduler.inv_step(noise_pred_avg, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()

    # perform noise correction
    if pipe.cfg.perform_noise_correction:
        noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds, added_cond_kwargs)

        # perform guidance
        if pipe.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)
        
        pipe.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, optimize_epsilon_type=pipe.cfg.perform_noise_correction)

    return approximated_z_tp1
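
For orientation, here is a rough sketch of how inversion_step might be driven (illustrative only, not repo code): inversion walks the scheduler's timesteps from low noise to high noise, so the schedule is flipped relative to sampling, and prompt_embeds / added_cond_kwargs are assumed to come from the pipeline's prompt-encoding step.

# Hypothetical driver loop; pipe.z_0 is the encoded input image, as in
# inversion_step above.
z_t = pipe.z_0.clone()
for t in pipe.scheduler.timesteps.flip(0):  # ascending noise levels
    z_t = inversion_step(pipe, z_t, t, prompt_embeds, added_cond_kwargs,
                         generator=generator)
z_T = z_t  # inverted latent, the starting point for sampling/editing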


Georgefwt commented on September 26, 2024

I tested the code generated by Claude and found that it cannot be used directly. I had to make the following modifications to the noise_regularization function:

def noise_regularization(
    e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls, generator=None, do_classifier_free_guidance=False
):
    for _outer in range(num_reg_steps):
        if lambda_kl > 0:
            # _var stays a leaf tensor, so backward() populates _var.grad;
            # only the doubled copy __var goes through the loss
            _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
            if do_classifier_free_guidance:
                __var = torch.cat([_var] * 2)
            else:
                __var = _var  # keep __var defined when CFG is off
            l_kld = patchify_latents_kl_divergence(__var, noise_pred_optimal)
            l_kld.backward()
            _grad = _var.grad.detach()
            _grad = torch.clip(_grad, -100, 100)
            if do_classifier_free_guidance:
                _grad = _grad.chunk(2)[0]
            e_t = e_t - lambda_kl * _grad
        if lambda_ac > 0:
            for _inner in range(num_ac_rolls):
                _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                l_ac = auto_corr_loss(_var, generator=generator)
                l_ac.backward()
                _grad = _var.grad.detach() / num_ac_rolls
                if do_classifier_free_guidance:
                    _grad = _grad.chunk(2)[0]
                e_t = e_t - lambda_ac * _grad
        e_t = e_t.detach()

    return e_t
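
A quick way to sanity-check this fixed version in isolation is a shape-level smoke test with stand-in losses (everything below is a hypothetical harness, not repo code; the real patchify_latents_kl_divergence and auto_corr_loss are swapped for trivial stubs):

import torch

# Trivial stand-ins that only exercise the autograd path.
def patchify_latents_kl_divergence(x, y):
    return ((x - y) ** 2).mean()

def auto_corr_loss(x, generator=None):
    return (x * torch.roll(x, shifts=1, dims=-1)).mean()

e_t = torch.randn(1, 4, 64, 64)
noise_pred_optimal = torch.randn(2, 4, 64, 64)  # doubled batch, matching __var

out = noise_regularization(
    e_t, noise_pred_optimal,
    lambda_kl=0.05, lambda_ac=0.05,  # arbitrary values for the test
    num_reg_steps=2, num_ac_rolls=2,
    do_classifier_free_guidance=True,
)
assert out.shape == e_t.shape

With the fix, _var stays a leaf tensor, so _var.grad is populated after backward(); in Claude's version the cat reassignment made it a non-leaf and _var.grad came back as None.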

Now the inversion runs properly, but editability is still poor (I was unable to change the lion into a tiger).


garibida commented on September 26, 2024

Hi,

First, I uploaded a commit that fixed the errors when using CFG greater than 1.0.
Second, regarding the question about using CFG=0.0 in the config and CFG=1.0 during inference: if the CFG value is less than or equal to 1.0, the diffusion pipeline does not perform CFG at all, so it doesn't really matter.
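
Concretely, with the standard combination uncond + s * (text - uncond), a scale of exactly 1.0 reduces to the text-conditioned prediction alone, so running the unconditional branch would change nothing. A minimal sketch of the check (the exact property name and any extra conditions vary across diffusers versions):

def should_do_classifier_free_guidance(guidance_scale: float) -> bool:
    # At scale == 1.0: uncond + 1.0 * (text - uncond) == text, so CFG
    # (and the doubled batch it requires) only pays off for scale > 1.0.
    return guidance_scale > 1.0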


Georgefwt commented on September 26, 2024

Sorry, but I tested the updated code with guidance_scale: float = 4.0 set in src/config.py, and this is the reconstruction result I got. Is there something I did wrong?

