
Comments (9)

yue-zhongqi avatar yue-zhongqi commented on May 17, 2024 9

I managed to find a workaround using diffusers==0.14.0. Reason for the failure: in the newer diffusers library, the signature of the CrossAttention forward function changed. In the old 0.3.0 version, the default forward function does not use the mask input (required by Prompt2Prompt), so Prompt2Prompt replaces the forward function in ptp_utils.py via the register_attention_control function. This replacement forward conflicts with the updated forward signature in the newer diffusers library (0.14.0).
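The signature mismatch can be sketched as follows (a minimal illustration using the argument names from the discussion, not actual diffusers code):

```python
# Old CrossAttention.forward signature (diffusers 0.3.0 era), which the
# patched forward in ptp_utils.register_attention_control was written against:
def old_forward(x, context=None, mask=None):
    return x, context, mask

# New signature (diffusers 0.14.0): same tensors under new argument names,
# plus a kwargs catch-all that the old patched forward does not accept.
def new_forward(hidden_states, encoder_hidden_states=None,
                attention_mask=None, **cross_attention_kwargs):
    # Rebind to the old names so the old patched body keeps working
    x = hidden_states
    context = encoder_hidden_states
    mask = attention_mask
    return x, context, mask
```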

  • In ptp_utils.py, in def register_attention_control(model, controller):, change def forward(x, context=None, mask=None): to def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):. Then add x = hidden_states, context = encoder_hidden_states, and mask = attention_mask at the top of the function.
  • In diffusers/models/cross_attention.py, add the following two functions to the CrossAttention class. They are called inside the registered forward function but were removed from the CrossAttention class in the newer diffusers library.
    def reshape_heads_to_batch_dim(self, tensor):
        # (batch, seq_len, dim) -> (batch * heads, seq_len, dim // heads)
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        # (batch * heads, seq_len, dim) -> (batch, seq_len, dim * heads)
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor
  • Alternatively, one can define the two functions in ptp_utils.py and modify the forward function defined in register_attention_control to call them instead.
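To sanity-check what the two helpers do, here is a NumPy port of the same reshapes (NumPy stands in for torch, and self.heads is passed explicitly; everything else mirrors the code above). The split to (batch * heads) and the merge back are exact inverses:

```python
import numpy as np

heads = 4

def reshape_heads_to_batch_dim(tensor, head_size=heads):
    # (batch, seq_len, dim) -> (batch * heads, seq_len, dim // heads)
    b, s, d = tensor.shape
    t = tensor.reshape(b, s, head_size, d // head_size)
    return t.transpose(0, 2, 1, 3).reshape(b * head_size, s, d // head_size)

def reshape_batch_dim_to_heads(tensor, head_size=heads):
    # (batch * heads, seq_len, dim) -> (batch, seq_len, dim * heads)
    b, s, d = tensor.shape
    t = tensor.reshape(b // head_size, head_size, s, d)
    return t.transpose(0, 2, 1, 3).reshape(b // head_size, s, d * head_size)

x = np.random.rand(2, 5, 8)          # (batch, seq_len, dim), dim divisible by heads
y = reshape_heads_to_batch_dim(x)
assert y.shape == (8, 5, 2)          # (2 * 4, 5, 8 // 4)
assert np.allclose(reshape_batch_dim_to_heads(y), x)   # exact round trip
```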

from prompt-to-prompt.

nihaomiao avatar nihaomiao commented on May 17, 2024 4

The suggestions from @yue-zhongqi work. But instead of changing the CrossAttention class in diffusers/models/cross_attention.py directly, one can also reuse the official head_to_batch_dim and batch_to_head_dim functions in the newer diffusers version in place of reshape_heads_to_batch_dim and reshape_batch_dim_to_heads. In short, replace the original def forward(x, context=None, mask=None) function in def register_attention_control(model, controller) of ptp_utils.py with the following code:

  def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
      # Map the new diffusers argument names back onto the old ones
      x = hidden_states
      context = encoder_hidden_states
      mask = attention_mask
      batch_size, sequence_length, dim = x.shape
      h = self.heads
      q = self.to_q(x)
      is_cross = context is not None
      context = context if is_cross else x
      k = self.to_k(context)
      v = self.to_v(context)
      q = self.head_to_batch_dim(q)
      k = self.head_to_batch_dim(k)
      v = self.head_to_batch_dim(v)

      sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

      if mask is not None:
          mask = mask.reshape(batch_size, -1)
          max_neg_value = -torch.finfo(sim.dtype).max
          mask = mask[:, None, :].repeat(h, 1, 1)
          sim.masked_fill_(~mask, max_neg_value)

      # attention, what we cannot get enough of
      attn = sim.softmax(dim=-1)
      attn = controller(attn, is_cross, place_in_unet)
      out = torch.einsum("b i j, b j d -> b i d", attn, v)
      out = self.batch_to_head_dim(out)
      # to_out is captured from the enclosing register_attention_control scope
      return to_out(out)
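As a sanity check on the attention math above, the einsum pipeline can be reproduced in NumPy with toy shapes (NumPy replaces torch here; the controller call marks where Prompt2Prompt hooks in):

```python
import numpy as np

bh, i, j, d = 8, 5, 7, 4        # batch*heads, query len, key/value len, head dim
rng = np.random.default_rng(0)
q = rng.standard_normal((bh, i, d))
k = rng.standard_normal((bh, j, d))
v = rng.standard_normal((bh, j, d))

scale = d ** -0.5               # the usual 1/sqrt(head dim) scaling
sim = np.einsum("bid,bjd->bij", q, k) * scale

# Numerically stable softmax over the key dimension
attn = np.exp(sim - sim.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)
# ... attn = controller(attn, is_cross, place_in_unet) is applied here in ptp_utils

out = np.einsum("bij,bjd->bid", attn, v)
assert out.shape == (bh, i, d)  # one output vector per query position
```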


HiddenGalaxy avatar HiddenGalaxy commented on May 17, 2024 1

Did you solve the problem? I have the same problem but don't know how to solve it.


patrickvonplaten avatar patrickvonplaten commented on May 17, 2024

Here 👋

Maintainer of the diffusers library here - should we try to add a prompt-to-prompt pipeline to diffusers to make sure things are actively maintained?


aliasgharkhani avatar aliasgharkhani commented on May 17, 2024

@HosamGen This error happens because you are using a newer version of the diffusers library. Downgrading to diffusers==0.3.0 should solve the problem. But when I downgrade and run this line: ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=MY_TOKEN, scheduler=scheduler).to(device), it does not proceed and gets stuck there.


HosamGen avatar HosamGen commented on May 17, 2024

@aliasgharkhani For the latent model notebook this fixes it, but for the stable model it does not proceed; I face the same issue you mentioned.


HosamGen avatar HosamGen commented on May 17, 2024

@patrickvonplaten
Sure, that would be great.


HosamGen avatar HosamGen commented on May 17, 2024

@yue-zhongqi This works, thank you very much.


momoshenchi avatar momoshenchi commented on May 17, 2024

That works!


