Colab notebook about glide-text2im (closed)

openai commented on July 16, 2024
Colab notebook

Comments (6)

woctezuma commented on July 16, 2024

I cannot reproduce the same results as shown in the paper's Figure 1 with the same text prompt.

Unfortunately, this is normal, because the publicly available model:

  • is smaller, as it has roughly 10x fewer parameters,
  • was trained on a filtered dataset.

You should get outputs similar to the third row of Figure 9.

[Figure 9 from the paper]

From a user perspective, the main benefit of GLIDE is that it is much faster than the CLIP-guided methods which I have tried so far.

Is the base the only checkpoint available for the base diffusion model?

I think so. From what I can see in the code below, there are 6 checkpoints:

  • two for classifier-free guidance (sampling and upsampling),
  • two for inpainting (sampling and upsampling),
  • two for CLIP (text encoding and image encoding).

MODEL_PATHS = {
    "base": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt",
    "upsample": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt",
    "base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt",
    "upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt",
    "clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt",
    "clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt",
}
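
In the repo's notebooks these names are passed straight to load_checkpoint, which downloads and caches the corresponding weights. A minimal sketch, assuming device is already set up as in the notebooks:

from glide_text2im.download import load_checkpoint

# Any key from MODEL_PATHS is a valid checkpoint name; for example, the two
# classifier-free-guidance checkpoints:
base_state_dict = load_checkpoint('base', device)          # 64x64 text-to-image model
upsample_state_dict = load_checkpoint('upsample', device)  # 64 -> 256 upsampler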

woctezuma commented on July 16, 2024

I see the following nice commits:

  • 146bd9c add install command to notebooks -> git and pip at the start of the notebook,
  • f468908 add colab links -> Colab badges for the links in the README,
  • 9cc8e56 colab GPU backend -> GPU support toggled ON.

woctezuma commented on July 16, 2024

I can see that the sampling part is slightly different from yours, adding the model_fn function to the sample loop. Is this related to the fact that they just do classifier-free guidance (cond_fn=None) rather than CLIP guidance like in your Colab?

To clarify any confusion:

Unless I am missing something, the model_fn function is added to the sample loop in both notebooks called text2im.ipynb.

# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
model.del_cache()

Also, I have tried to combine the last two, and the results seem to be better, as if CLIP guidance introduces too much randomness for the small model. Any idea why?

I need to see the diff of what you did to understand better.

I would be glad to test this and see the results, if they are better. :) The black cat with white paws looks nice. 👍
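
For what it is worth, here is my guess at what combining the two could look like: keep the classifier-free model_fn and also pass the CLIP cond_fn to p_sample_loop. This is only a sketch pieced together from the snippets in this thread (the prompt list is sized to the doubled batch so the shapes line up); whether it makes sense to CLIP-guide the unconditional half is exactly the kind of thing I would want to test.

# Hypothetical combination of classifier-free guidance (model_fn) and
# CLIP guidance (cond_fn). Assumes model, diffusion, options, clip_model,
# prompt, batch_size, guidance_scale and model_kwargs are set up as in the notebooks.
full_batch_size = batch_size * 2

# The CLIP gradient function sees the full (cond + uncond) batch,
# so give it one prompt per sample in that batch.
cond_fn = clip_model.cond_fn([prompt] * full_batch_size, guidance_scale)

model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,                       # classifier-free guidance
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=cond_fn,                # CLIP guidance on top
)[:batch_size]
model.del_cache()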

loretoparisi commented on July 16, 2024

Thanks! I have two versions: this one,

samples = diffusion.p_sample_loop(
    model,
    (batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=cond_fn,
)

where

cond_fn = clip_model.cond_fn([prompt] * batch_size, guidance_scale)

and this one from the latest Colab in the repo,

samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]

with cond_fn=None and model_fn defined as

def model_fn(x_t, ts, **kwargs):
    # The batch holds [conditional, unconditional] halves; both use the same x_t.
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    # First 3 channels are the predicted noise (eps); the rest is the variance output.
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    # Classifier-free guidance: move eps away from the unconditional prediction.
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)
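
For reference, the clip_model used in the first version is set up from the CLIP checkpoints listed earlier. A rough sketch based on the repo's clip_guided notebook; treat the exact calls as an assumption rather than a verbatim copy:

from glide_text2im.clip.model_creation import create_clip_model
from glide_text2im.download import load_checkpoint

# CLIP model trained on noised images, split across the two CLIP checkpoints.
clip_model = create_clip_model(device=device)
clip_model.image_encoder.load_state_dict(load_checkpoint('clip/image-enc', device))
clip_model.text_encoder.load_state_dict(load_checkpoint('clip/text-enc', device))

cond_fn = clip_model.cond_fn([prompt] * batch_size, guidance_scale)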

loretoparisi commented on July 16, 2024

@woctezuma thanks!!! Is base the only checkpoint available for the base diffusion model? I cannot reproduce the same results as shown in the paper's Figure 1 with the same text prompt.
In the references I can also see CLIP-guided diffusion models for both 256x256 and 512x512.

Crowson, K. CLIP guided diffusion HQ 256x256. https://colab.research.google.com/drive/12a_Wrfi2_gwwAuN3VvMTwVMz9TfqctNj, 2021a.
Crowson, K. CLIP guided diffusion 512x512, secondary model method. https://twitter.com/RiversHaveWings/status/1462859669454536711, 2021b.

loretoparisi commented on July 16, 2024

@woctezuma thanks! I can see that the sampling part is slightly different from yours, adding the model_fn function to the sample loop. Is this related to the fact that they just do classifier-free guidance (cond_fn=None) rather than CLIP guidance like in your Colab? Also, I have tried to combine the last two, and the results seem to be better, as if CLIP guidance introduces too much randomness for the small model. Any idea why?

# Create the text tokens to feed to the model.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    tokens, options['text_ctx']
)

# Create the classifier-free guidance tokens (empty)
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
    [], options['text_ctx']
)

# Pack the tokens together into model kwargs.
model_kwargs = dict(
    tokens=th.tensor(
        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
    ),
    mask=th.tensor(
        [mask] * batch_size + [uncond_mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),
)

# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)

# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
model.del_cache()

# Show the output
show_images(samples)

[output image: cat2]
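
For anyone copying the snippet above: it assumes the setup cells from the repo's text2im notebook have already been run. Roughly something like the following, with the exact option values being an assumption on my part:

import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)

device = th.device('cuda' if th.cuda.is_available() else 'cpu')

# Base 64x64 model and its diffusion process.
options = model_and_diffusion_defaults()
options['timestep_respacing'] = '100'  # fewer diffusion steps for faster sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))

# Sampling parameters referenced by the snippet.
prompt = "a black cat with white paws"
batch_size = 1
guidance_scale = 3.0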
