
min-dalle's Introduction

min(DALL·E)

Colab   Hugging Face Spaces   Replicate   Discord

YouTube Walk-through by The AI Epiphany

This is a fast, minimal port of Boris Dayma's DALL·E Mini (with mega weights). It has been stripped down for inference and converted to PyTorch. The only third-party dependencies are numpy, requests, pillow, and torch.

To generate a 3x3 grid of DALL·E Mega images it takes:

  • 55 sec with a T4 in Colab
  • 33 sec with a P100 in Colab
  • 15 sec with an A10G on Hugging Face

Here's a more detailed breakdown of performance on an A100. Credit to @technobird22 and his NeoGen discord bot for the graph.

The flax model and code for converting it to torch can be found here.

Install

$ pip install min-dalle

Usage

Load the model parameters once and reuse the model to generate multiple images.

import torch
from min_dalle import MinDalle

model = MinDalle(
    models_root='./pretrained',
    dtype=torch.float32,
    device='cuda',
    is_mega=True,
    is_reusable=True
)

The required models will be downloaded to models_root if they are not already there. Set the dtype to torch.float16 to save GPU memory. If you have an Ampere-architecture GPU, you can use torch.bfloat16. Set the device to either "cuda" or "cpu".

Once everything has finished initializing, call generate_image with some text as many times as you want. Use a positive seed for reproducible results. Higher values for supercondition_factor result in better agreement with the text but a narrower variety of generated images.

Every image token is sampled from the top_k most probable tokens. The largest logit is subtracted from the logits to avoid infs, and the logits are then divided by the temperature. If is_seamless is true, the image grid will be tiled in token space rather than pixel space.
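To make the sampling description concrete, here is a minimal NumPy sketch of a single sampling step. The helper name and signature are hypothetical, not min-dalle's API, and it assumes super conditioning blends the two sets of logits as (1 - factor) * unconditioned + factor * conditioned, a common formulation that this page does not confirm. The top_k filtering, the max-logit subtraction, and the division by temperature follow the order described above.

```python
import numpy as np


def sample_image_token(
    logits_cond: np.ndarray,
    logits_uncond: np.ndarray,
    supercondition_factor: float,
    top_k: int,
    temperature: float,
    rng: np.random.Generator,
) -> int:
    """Sample one image token from one position's logits (hypothetical helper)."""
    # Assumed super-conditioning blend: a factor of 1 uses the text-conditioned
    # logits alone; larger factors push further away from the unconditioned ones.
    logits = (
        (1 - supercondition_factor) * logits_uncond
        + supercondition_factor * logits_cond
    )
    # Keep the top_k most probable tokens (ties at the cutoff are also kept).
    cutoff = np.sort(logits)[-top_k]
    is_kept = logits >= cutoff
    # Subtract the largest logit so exp() cannot overflow to inf.
    logits = logits - logits.max()
    # Divide by the temperature, exponentiate, and zero out discarded tokens.
    weights = np.exp(logits / temperature) * is_kept
    probs = weights / weights.sum()
    return int(rng.choice(len(probs), p=probs))


rng = np.random.default_rng(seed=0)
cond = np.array([2.0, 1.0, 0.5, -1.0])
uncond = np.zeros(4)
# With top_k=1 only the most probable token can be drawn
print(sample_image_token(cond, uncond, 16, 1, 1.0, rng))  # 0
```

Because the largest logit becomes 0 before exponentiation, every weight is at most 1 (for a positive temperature), which is what keeps the computation free of infs.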

image = model.generate_image(
    text='Nuclear explosion broccoli',
    seed=-1,
    grid_size=4,
    is_seamless=False,
    temperature=1,
    top_k=256,
    supercondition_factor=32,
    is_verbose=False
)

display(image)


Credit to @hardmaru for the example

Saving Individual Images

The images can also be generated as a FloatTensor in case you want to process them manually.

images = model.generate_images(
    text='Nuclear explosion broccoli',
    seed=-1,
    grid_size=3,
    is_seamless=False,
    temperature=1,
    top_k=256,
    supercondition_factor=16,
    is_verbose=False
)

To get an image into PIL format you will have to first move the images to the CPU and convert the tensor to a numpy array.

images = images.to('cpu').numpy()

Then image i can be converted to a PIL.Image and saved:

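from PIL import Image
# astype('uint8') is a no-op for uint8 data and truncates float pixel values in [0, 255]
image = Image.fromarray(images[i].astype('uint8'))
image.save('image_{}.png'.format(i))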

Progressive Outputs

If the model is being used interactively (e.g. in a notebook), generate_image_stream can be used to generate a stream of images as the model is decoding. The detokenizer adds a slight delay for each image. Set progressive_outputs to True to enable this. An example is implemented in the Colab.

image_stream = model.generate_image_stream(
    text='Dali painting of WALL·E',
    seed=-1,
    grid_size=3,
    progressive_outputs=True,
    is_seamless=False,
    temperature=1,
    top_k=256,
    supercondition_factor=16,
    is_verbose=False
)

for image in image_stream:
    display(image)
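The internals of generate_image_stream are not shown on this page. The pure-Python sketch below only illustrates the general pattern of streaming progressive outputs from an autoregressive decoder: sample tokens one at a time and periodically detokenize what exists so far. The names next_token, detokenize, every and pad_token are hypothetical, and filling not-yet-sampled positions with a placeholder token is an assumption.

```python
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")


def progressive_decode(
    next_token: Callable[[List[int]], int],
    detokenize: Callable[[List[int]], T],
    token_count: int = 256,
    every: int = 32,
    pad_token: int = 0,
) -> Iterator[T]:
    """Yield a partially decoded image every `every` tokens, then the final one."""
    tokens: List[int] = []
    for i in range(token_count):
        # Autoregressive step: the next token depends on the tokens so far
        tokens.append(next_token(tokens))
        # Detokenizing is comparatively slow, so only do it periodically
        if (i + 1) % every == 0 or i + 1 == token_count:
            # Positions not sampled yet are filled with a placeholder token
            yield detokenize(tokens + [pad_token] * (token_count - len(tokens)))


# Toy stand-ins: the "decoder" emits the current length, "detokenize" is identity
for frame in progressive_decode(lambda toks: len(toks), lambda toks: toks, token_count=8, every=4):
    print(frame)
# [0, 1, 2, 3, 0, 0, 0, 0]
# [0, 1, 2, 3, 4, 5, 6, 7]
```

The every parameter captures the trade-off mentioned above: each detokenization adds delay, so yielding less often finishes sooner.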


Command Line

Use image_from_text.py to generate images from the command line.

$ python image_from_text.py --text='artificial intelligence' --no-mega
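The source of image_from_text.py is not reproduced here. As a hedged illustration of how a --mega/--no-mega pair of flags can be parsed, here is a hypothetical argparse setup. The --text and --no-mega flags come from the command above and --seed appears in commands elsewhere on this page. The defaults are guesses, not the script's actual values.

```python
import argparse


def parse_args(argv=None) -> argparse.Namespace:
    # Hypothetical parser; flag names mirror the command above, defaults are guesses
    parser = argparse.ArgumentParser(description="Generate images from text")
    parser.add_argument("--text", type=str, default="alien life")
    # BooleanOptionalAction (Python 3.9+) creates both --mega and --no-mega
    parser.add_argument("--mega", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=-1)
    return parser.parse_args(argv)


args = parse_args(["--text=artificial intelligence", "--no-mega"])
print(args.text, args.mega)  # artificial intelligence False
```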


min-dalle's People

Contributors

20kdc, ak391, andrewginns, chenxwh, ewpratten, haydn-jones, interfect, jedahan, kanttouchthis, kuprel, neverix, osanseviero, raphant, rupa, w4ffl35


min-dalle's Issues

High memory usage of torch compared to flax

Hi,

Thanks for this project, you gave me hope to give my AMD GPU some fun :-) (torch is much easier to get to work on it than XLA)

That being said, it doesn't work yet, because the torch variant eats much more RAM than the flax one:

/usr/bin/time -v python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --seed=4
        Maximum resident set size (kbytes): 21.096.396

/usr/bin/time -v python image_from_text.py --text='a comfy chair that looks like an avocado' --seed=4
        Maximum resident set size (kbytes): 4.152.744

Would you have some pointers as to how to optimize that?

msgpack.exceptions.ExtraData: unpack(b) received extra data.

Trying to run the sample

python3 image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4

on a new local installation on macOS Monterey yields the following error:

detokenizing image
Traceback (most recent call last):
  File "/Users/user/code/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/user/code/min-dalle/min_dalle/generate_image.py", line 75, in generate_image_from_text
    image = detokenize_torch(torch.tensor(image_tokens))
  File "/Users/user/code/min-dalle/min_dalle/min_dalle_torch.py", line 108, in detokenize_torch
    params = load_vqgan_torch_params(model_path)
  File "/Users/user/code/min-dalle/min_dalle/load_params.py", line 12, in load_vqgan_torch_params
    params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
  File "/usr/local/lib/python3.9/site-packages/flax/serialization.py", line 350, in msgpack_restore
    state_dict = msgpack.unpackb(
  File "msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb
msgpack.exceptions.ExtraData: unpack(b) received extra data.

Is anyone running the sample prompt successfully on a Mac? It is unclear to me whether the error is library- or data-related. Has anyone had a similar problem?

ValueError: Unpack failed: incomplete input

Got this error on Ubuntu 20.04 with Python 3.8.10.

$ python image_from_text.py --text 'alien life' --seed 7
Namespace(image_path='generated', image_token_count=256, mega=False, seed=7, text='alien life', torch=False)
parsing metadata from ./pretrained/dalle_bart_mini
tokenizing text
['Ġalien']
['Ġlife']
text tokens [0, 8925, 742, 2]
Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/home/jchia/gh/min-dalle/min_dalle/generate_image.py", line 54, in generate_image_from_text
    params_dalle_bart = load_dalle_bart_flax_params(model_path)
  File "/home/jchia/gh/min-dalle/min_dalle/load_params.py", line 44, in load_dalle_bart_flax_params
    params = serialization.msgpack_restore(f.read())
  File "/home/jchia/venv/pt/lib/python3.8/site-packages/flax/serialization.py", line 350, in msgpack_restore
    state_dict = msgpack.unpackb(
  File "msgpack/_unpacker.pyx", line 205, in msgpack._cmsgpack.unpackb
ValueError: Unpack failed: incomplete input

pretrained/dalle_bart_mini/flax_model.msgpack was somehow not deserializing properly.

Using --torch flag results in RuntimeError exception

Command run

python image_from_text.py --text='harry potter vampire' --torch

Result

Fatal exception, unable to continue script execution.

Error raised

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

Full stacktrace

Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/mnt/g/Documents/Projects/min_dalle/min_dalle/generate_image.py", line 58, in generate_image_from_text
    image_tokens[:image_token_count] = generate_image_tokens_torch(
  File "/mnt/g/Documents/Projects/min_dalle/min_dalle/min_dalle_torch.py", line 89, in generate_image_tokens_torch
    encoder_state = encode_torch(
  File "/mnt/g/Documents/Projects/min_dalle/min_dalle/min_dalle_torch.py", line 40, in encode_torch
    encoder_state = encoder(text_tokens)
  File "/home/USER/envs/mindalle/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/g/Documents/Projects/min_dalle/min_dalle/models/dalle_bart_encoder_torch.py", line 133, in forward
    self.embed_tokens.forward(text_tokens) +
  File "/home/USER/envs/mindalle/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/home/USER/envs/mindalle/lib/python3.8/site-packages/torch/nn/functional.py", line 2199, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

System environment

  • Ubuntu 20 on WSL
  • python3 via venv
  • Latest code on main branch

I'm guessing we need GPU specific models?

Add PyPI install

First off, great project, thank you for building this!

Have you considered building a package for PyPI? The models are available on the Hugging Face Hub and the Hugging Face Hub package can be used to programmatically download/cache the models.

Another idea is that encoder.pt and decoder.pt could also be uploaded to a new model on the Hub so they don't need to be dynamically converted at runtime for the Torch model.

One last thing to consider is changing prints to log statements, so they could be turned off if desired.

The general idea is that someone could run pip install and then be up and running regardless of the runtime platform (Windows/Linux/macOS).
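The suggestion to replace prints with log statements can be sketched with Python's standard logging module. The logger name "min_dalle", the configure_logging helper, and the toy tokenize function are all hypothetical; they only show how an is_verbose-style flag could map onto log levels so callers can silence progress output.

```python
import logging

# Hypothetical logger name; the real package may not define one
logger = logging.getLogger("min_dalle")


def configure_logging(is_verbose: bool) -> None:
    """Route progress messages through logging so callers can silence them."""
    logging.basicConfig(format="%(name)s %(levelname)s: %(message)s")
    logger.setLevel(logging.DEBUG if is_verbose else logging.WARNING)


def tokenize(text: str) -> list:
    tokens = text.split()  # placeholder, not the real BPE tokenizer
    logger.debug("text tokens %s", tokens)  # replaces a print() call
    return tokens


configure_logging(is_verbose=False)
tokenize("alien life")  # nothing is emitted at WARNING level
```

Applications embedding the library could then adjust the level of the "min_dalle" logger themselves instead of patching print calls.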

BusError only with mega model

Running the command from the README:

python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4

gives me this output:

Namespace(mega=True, torch=False, text='a comfy chair that looks like an avocado', seed=4, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
loading flax encoder
encoding text tokens
2022-06-28 11:32:39.180526: E external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:202] Unable to resolve runtime symbol: `___extendhfsf2'.  Hint: if the symbol a custom call target, make sure you've registered it with the JIT using XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.
JIT session error: Symbols not found: [ ___extendhfsf2 ]
Bus error: 10

This only happens with the mega model.

CUDA error: device-side assert triggered

When running under high load on the A100, this error comes up after about 10 minutes. This happens on Replicate and on the Discord bot. The best fix so far is simply to restart the server every 10 minutes. If anyone has a better fix, please post it here. Thanks!

CUDA out of memory: seems to always allocate all of it.

Hi,

I can't get the model working using the Replicate image (r8.im/kuprel/min-dalle@sha256:71b9ef81385fae73b632d7e2fe0f5988a739781e833a610a9c83bc45205d8215) on any GPU because of OOM errors. I tried progressively larger GPUs, from 16 GB up to 48 GB of VRAM. Does it really require more, or is something wrong?

On an RTX A6000 with 48 GB of video memory (cloud Docker containers), requesting a prediction gives:
2022-07-13T11:57:56.509707773Z RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 47.54 GiB total capacity; 45.68 GiB already allocated; 3.56 MiB free; 46.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Doesn't use both threads on each core when running on CPU

I'm using a machine with an 8 core, 16 thread CPU, and plenty of CPU memory, but no GPU compatible with current ML toolkits; ROCm dropped support for my hardware a while ago. So I want to run this on CPU as efficiently as I can.

Unfortunately, it seems like I am only getting somewhere between half and 80% of the performance I think I ought to be getting out of the CPU backend.

When I run time python image_from_text.py --text='alien life' --seed=7 --no-torch, it seems to only be able to use one thread on each two-thread core, gets up to about 500% CPU in htop, and reports:

real	1m9.092s
user	5m16.416s
sys	0m10.788s

When I run time python image_from_text.py --text='alien life' --seed=7 --torch, I did manage to catch it at more like 700% CPU. It runs a bit faster but still doesn't seem to be fully using my CPU:

real	0m51.015s
user	4m54.248s
sys	0m14.651s

I also get this different and much more terrifying image; I figured the same seed would produce the same result with both engines, but I was wrong.

Anyway, I would expect CPU usage to be closer to 1600%, and user + sys times to be more like 16x real times, if I were actually managing to use both threads on each of the 8 physical cores at full tilt.

Is there something about the backend that is causing it to only try and use one thread per full core, and not one thread per hardware thread? Is that something that I can change?
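As a rough check on the posted timings, (user + sys) / real gives the average number of busy hardware threads over a run; on an 8-core, 16-thread CPU the ceiling is 16 (1600% in htop terms). This small worked calculation uses only the numbers reported above.

```python
def cpu_utilization(real_s: float, user_s: float, sys_s: float) -> float:
    """Average number of busy hardware threads: (user + sys) / real."""
    return (user_s + sys_s) / real_s


flax_run = cpu_utilization(real_s=69.092, user_s=5 * 60 + 16.416, sys_s=10.788)
torch_run = cpu_utilization(real_s=51.015, user_s=4 * 60 + 54.248, sys_s=14.651)
print(f"flax:  {flax_run:.2f} threads busy on average")  # 4.74
print(f"torch: {torch_run:.2f} threads busy on average")  # 6.06
```

Both averages sit below the 8 physical cores, not just below 16, which is consistent with the roughly 500% and 700% peaks seen in htop. For the torch path, torch.set_num_threads controls intra-op parallelism, but whether raising it past the physical core count helps on this workload is untested here.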

cannot download datasets after login

When running the setup script, it does not seem like I can download the models from wandb:

wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0

wandb: Downloading dataset artifact dalle-mini/dalle-mini/mini-1:v0
Traceback (most recent call last):
  File "/usr/local/bin/wandb", line 8, in <module>
    sys.exit(cli())
  File "/usr/local/lib/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.9/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/usr/local/lib/python3.9/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/usr/local/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/wandb/cli/cli.py", line 96, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/wandb/cli/cli.py", line 1657, in get
    path = artifact.download(root=root)
  File "/usr/local/lib/python3.9/site-packages/wandb/apis/public.py", line 3867, in download
    manifest = self._load_manifest()
  File "/usr/local/lib/python3.9/site-packages/wandb/apis/public.py", line 4135, in _load_manifest
    with requests.get(index_file_url) as req:
AttributeError: __enter__

Cannot install flax because package versions have conflicting dependencies

I'm getting this error when I try and install the requirements with pip:

ERROR: Cannot install flax because these package versions have conflicting dependencies.

The conflict is caused by:
    optax 0.1.2 depends on jaxlib>=0.1.37
    optax 0.1.1 depends on jaxlib>=0.1.37
    optax 0.1.0 depends on jaxlib>=0.1.37
    optax 0.0.91 depends on jaxlib>=0.1.37
    optax 0.0.9 depends on jaxlib>=0.1.37
    optax 0.0.8 depends on jaxlib>=0.1.37
    optax 0.0.6 depends on jaxlib>=0.1.37
    optax 0.0.5 depends on jaxlib>=0.1.37
    optax 0.0.3 depends on jaxlib>=0.1.37
    optax 0.0.2 depends on jaxlib>=0.1.37
    optax 0.0.1 depends on jaxlib>=0.1.37

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts

I'm not sure what the fix is here as I'm not familiar with optax or jaxlib. Could someone clarify?

I'm on Windows, using Python 3.9.13 and pip 22.1.2.

Streaming intermediate images?

Is it possible to publish an update of the model that supports streaming intermediate images during reverse diffusion, i.e. with an iterator? It would greatly help UX if the user could see their image form while waiting for the process to finish.

Add requirements.txt file

I think it is best practice to maintain a requirements.txt file, so as to track which libraries have changed and to ensure that a stable version of min-dalle can be installed (installing the latest torch in setup_torch.sh with no other requirements frozen can cause stability issues for apps and sites that use min-dalle).

Branches + Versions for maintainability

Is it possible to maintain feature/bug-fix branches and a corresponding stable branch? A lot of projects will be using this repository, so frequent API changes might break their code quickly. Releasing versions might also be useful, since the changes (especially breaking changes) can be documented in the changelog/release notes.

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

When running python image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=100 or anything else with --mega

Output is the following:

Namespace(mega=True, torch=False, text='court sketch of godzilla on trial', seed=100, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġcourt']
['Ġsketch']
['Ġof']
['Ġgodzilla']
['Ġon']
['Ġtrial']
text tokens [0, 2634, 4189, 111, 14450, 133, 5167, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "/Users/REDACTED/workspace/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 218, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 770, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 754, in scanned
    c, y = fn(scope, c, *args)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 307, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
    return dynamic_update_slice_p.bind(operand, update, *start_indices)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
    out_aval, effects = primitive.abstract_eval(*avals, **params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 359, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
    lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/REDACTED/workspace/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

Running on an M1 MacBook Pro.

Is inpainting a possibility?

Hey, amazing work! Would it be possible to integrate text-guided image completion or inpainting with this model? If so, what would be the general direction for an implementation?

I'm assuming the model needs to be fed a rough estimate of an image's latent data and then fill in the missing pieces. I can give it a go if someone can point me in the right direction. Thanks!

Add Docker compatibility

Adding instructions for running min(DALL·E) in Docker can solve most issues caused by differences in environment and version numbers.

CUDA out of memory

RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 6.00 GiB total capacity; 5.32 GiB already allocated; 0 bytes free; 5.32 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Is it possible to run this neural network on a card with 6 GB of VRAM somehow?
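This is not a confirmed answer, but two levers mentioned on this page can be sketched: the allocator hint in the error message (PYTORCH_CUDA_ALLOC_CONF with max_split_size_mb, which targets fragmentation when reserved memory is much larger than allocated memory) and the README's advice to use torch.float16. The 128 MB split size is an arbitrary example, and the 1-billion-parameter model is hypothetical, not min-dalle's real parameter count; the arithmetic only shows why halving bytes per parameter halves weight memory.

```python
import os

# PyTorch reads this when it initializes its CUDA caching allocator, so set it
# before any CUDA work (simplest: before importing torch). 128 is arbitrary.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")


def weight_memory_gib(num_params: int, bytes_per_param: int) -> float:
    """GiB needed just to hold the weights (activations and caches excluded)."""
    return num_params * bytes_per_param / 2**30


# Hypothetical 1-billion-parameter model, not min-dalle's actual size
print(f"float32: {weight_memory_gib(10**9, 4):.2f} GiB")  # float32: 3.73 GiB
print(f"float16: {weight_memory_gib(10**9, 2):.2f} GiB")  # float16: 1.86 GiB
```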

CUDA not available

When I run image_from_text.py I get an error saying "User provided device_type of 'cuda', but CUDA is not available. Disabling" in autocast_mode.py

I have an RTX 3070 Ti so CUDA should work.

setup.sh fails due to possible CRLF issue

Hi!
I was trying to run this project in Ubuntu 20.04.4 (in WSL 2 under Windows 11), and after executing sh setup.sh I received the following:

efim@DESKTOP-PSEQ17Q:/mnt/f/Projects/min-dalle$ sh setup.sh
: not found:
ERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt\r'
: not found:
: not found:
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   234  100   234    0     0    390      0 --:--:-- --:--:-- --:--:--   389
: No ng: Failed to create the file ./pretrained/vqgan/flax_model.msgpack
Warning: such file or directory
  0  290M    0 15944    0     0  19163      0  4:24:39 --:--:--  4:24:39 19163
curl: (23) Failed writing body (0 != 15944)
: not found:
setup.sh: 11: wandb: not found
setup.sh: 12: wandb: not found
setup.sh: 13: wandb: not found
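The `\r` in 'requirements.txt\r' and the `: not found` lines suggest that setup.sh was checked out with Windows (CRLF) line endings, which is common when a repository is cloned on Windows with Git's `core.autocrlf` enabled. A small sketch that rewrites a file with Unix (LF) endings; `dos2unix setup.sh` does the same if it is installed. The helper name is hypothetical.

```python
from pathlib import Path


def convert_crlf_to_lf(path: str) -> bool:
    """Rewrite `path` in place with LF line endings. Returns True if anything changed."""
    p = Path(path)
    data = p.read_bytes()  # bytes, so no newline translation happens on read
    fixed = data.replace(b"\r\n", b"\n")
    if fixed == data:
        return False
    p.write_bytes(fixed)
    return True
```

For example, run `convert_crlf_to_lf("setup.sh")` and then rerun `sh setup.sh`; requirements.txt may need the same treatment.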

finetune code?

Good job! Do you have any plans to add fine-tuning code?

OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized.

python image_from_text.py --text='a comfy chair that looks like an avocado' --seed=4
OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.

Has anyone else seen this issue before? The suggested remedy seems to work.

Intel MBP 16"

--mega / is_mega raises TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

I tried running the following in the Google Colab:

image = generate_image_from_text("court sketch of godzilla on trial", is_mega=True, seed=100)

This caused an exception:

parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġcourt']
['Ġsketch']
['Ġof']
['Ġgodzilla']
['Ġon']
['Ġtrial']
text tokens [0, 2634, 4189, 111, 14450, 133, 5167, 2]
loading flax encoder
encoding text tokens
loading flax decoder
sampling image tokens
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-5-53d46ed9885c> in <module>()
      2 
----> 3 image = generate_image_from_text("court sketch of godzilla on trial", is_mega=True, seed=100)
      4 display(image)

67 frames
UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/content/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py in __call__(self, decoder_state, keys_state, values_state, attention_mask, state_index)
     38             keys_state,
     39             self.k_proj(decoder_state).reshape(shape_split),
---> 40             state_index
     41         )
     42         values_state = lax.dynamic_update_slice(

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The same thing happened when I tried running the command-line locally:

python3 image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=100

NOTE: I had to add the following line to the Setup block of the Jupyter code:

! wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
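One reading of this error: the attention layer writes freshly projected keys and values into a preallocated key/value cache with `lax.dynamic_update_slice`, and the cache and the projection end up with different dtypes (note that the mega checkpoint fetched above is an fp16 artifact). `dynamic_update_slice` does not promote dtypes, so both arguments must match; in JAX the usual workaround would be to cast the update first, e.g. `new_keys.astype(keys_state.dtype)`. The sketch below shows the same cast-then-write pattern with NumPy so it runs without JAX; the function name and the choice of axis 1 as the token axis are assumptions.

```python
import numpy as np


def write_into_cache(cache: np.ndarray, update: np.ndarray, index: int) -> np.ndarray:
    """Return a copy of `cache` with `update` written at `index` along axis 1.

    Assumes axis 1 is the token axis. The update is cast to the cache's dtype
    first, which is the step lax.dynamic_update_slice will not do for you.
    """
    update = update.astype(cache.dtype, copy=False)
    result = cache.copy()  # JAX arrays are immutable, so mimic a functional update
    result[:, index:index + update.shape[1]] = update
    return result
```

Casting the update to the cache dtype keeps memory use at the cache's precision; casting the cache up to float32 instead would also work but doubles its size.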

inference on a batch

Hello,
I'd like to ask whether it is possible to run inference on a batch of texts instead of a single text. My application accepts several text prompts at once, and it would be nice to process them in one batch to speed things up.
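A sketch of the usual first step for batching prompts, assuming the encoder can take a padded batch plus an attention mask (not something confirmed here): pad every tokenized prompt to the longest one, record which positions are real, and split large prompt lists into batches that fit in memory. The pad id of 1 is an assumption based on the BART convention the tokenizer appears to follow (the token lists in these logs start with 0 and end with 2); both helper names are hypothetical.

```python
from typing import List, Tuple

PAD_ID = 1  # assumption: BART-style vocabularies reserve id 1 for <pad>


def pad_batch(sequences: List[List[int]], pad_id: int = PAD_ID) -> Tuple[List[List[int]], List[List[bool]]]:
    """Right-pad token sequences to a common length and build a matching attention mask."""
    longest = max(len(seq) for seq in sequences)
    padded = [seq + [pad_id] * (longest - len(seq)) for seq in sequences]
    mask = [[True] * len(seq) + [False] * (longest - len(seq)) for seq in sequences]
    return padded, mask


def chunks(items: List[str], size: int) -> List[List[str]]:
    """Split prompts into batches of at most `size` so each batch fits in memory."""
    return [items[i:i + size] for i in range(0, len(items), size)]
```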

Commit ed91ab4a30ffdcf7e8773e6b434816f79f5fead8 causes mega to fail

Prior to commit ed91ab4, the following command took approximately 1 minute to run from start to finish:

python image_from_text.py --text="bart simpson on mars" --mega --seed=666 --image_path=bart-simpson-on-mars-666 --torch

After that commit, the command hangs at "loading decoder" and, after approximately 2 minutes, exits with Killed.

Text Tokenizer is fragmenting words

I'm running into unexpected behavior from the text tokenizer. I'm running this on Windows with Python 3.7, in a virtual environment, using the supplied image_from_text.py script.

The input text is tokenized in a way that breaks up the words, thus preventing the output from actually depicting what was requested:

'a comfy chair that looks like an avocado' ->

tokenizing text
['Ġ', 'a']
['Ġ', 'com', 'fy']
['Ġ', 'chair']
['Ġ', 'th', 'at']
['Ġ', 'look', 's']
['Ġ', 'like']
['Ġ', 'an']
['Ġ', 'av', 'oc', 'ado']
text tokens [0, 3, 28, 3, 157, 10065, 3, 10022, 3, 184, 73, 3, 7003, 46, 3, 19831, 3, 65, 3, 178, 158, 1165, 2]

'alien life' ->

tokenizing text
['Ġ', 'al', 'ien']
['Ġ', 'life']
text tokens [0, 3, 71, 1385, 3, 3210, 2]

Since the wrong tokens were chosen, the model returns a generic gaming chair for the first prompt and something like a petri dish for the second, which is expected given the garbled tokens.

I checked that the tokenizer.json files were downloaded correctly for both the mini and mega models, and they were: manually searching for the words finds them without any issue.

Is there a specific dependency for the text tokenizer that I'm unaware of or is this simply a bug?
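One plausible cause, offered as a guess: on Windows, `open()` without an `encoding` argument uses the locale encoding (often cp1252), so the two UTF-8 bytes of 'Ġ' (U+0120) decode as two other characters and no vocabulary entry starting with 'Ġ' matches, leaving only short fallback pieces like the ones shown. Reading the file explicitly as UTF-8 avoids that. The helper below is hypothetical and assumes the vocabulary is a flat token-to-id JSON mapping.

```python
import json
from typing import Dict


def load_vocab(path: str) -> Dict[str, int]:
    """Load a token -> id mapping, always decoding the file as UTF-8."""
    # Without encoding="utf-8", open() falls back to the platform locale
    # encoding, which mangles multi-byte characters such as 'Ġ'.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
```

Decoding 'Ġ' with the wrong codec shows the mismatch directly: `"Ġ".encode("utf-8").decode("cp1252")` gives `"Ä\xa0"`, which is not the prefix the tokenizer looks for.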

Input your own gen_top_p, gen_top_k parameters

In the Replicate demo you can already enter the seed manually, but I would also like to be able to enter my own gen_top_p and gen_top_k parameters. For example, for the same seed and prompt I'd like to try different gen_top_p and gen_top_k values and compare the results.
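The README describes the sampler as subtracting the largest logit, dividing by the temperature, and keeping the `top_k` most probable tokens; `top_p` (nucleus sampling) additionally keeps only the smallest set of most likely tokens whose probability mass reaches `top_p`. A pure-Python sketch of both filters on a single logit vector, as an illustration rather than min(DALL·E)'s code; only `top_k` and `temperature` appear in the README's API, so `top_p` support here is an assumption.

```python
import math
from typing import List, Sequence


def filter_probs(logits: Sequence[float], top_k: int = 0, top_p: float = 1.0,
                 temperature: float = 1.0) -> List[float]:
    """Turn logits into probabilities restricted by top-k and top-p (0 and 1.0 disable them)."""
    largest = max(logits)
    weights = [math.exp((x - largest) / temperature) for x in logits]
    total = sum(weights)
    order = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)

    kept = set(order[:top_k]) if top_k > 0 else set(order)

    # Nucleus: walk tokens from most to least likely until their mass reaches top_p.
    nucleus, mass = set(), 0.0
    for i in order:
        nucleus.add(i)
        mass += weights[i] / total
        if mass >= top_p:
            break
    kept &= nucleus

    norm = sum(weights[i] for i in kept)
    return [weights[i] / norm if i in kept else 0.0 for i in range(len(weights))]
```

With probabilities 0.5, 0.3 and 0.2, both `top_k=2` and `top_p=0.7` keep the first two tokens and renormalize them to 0.625 and 0.375, which makes side-by-side comparisons for a fixed seed easy to reason about.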

Make random seed monotonic in grid

If the random seed is '2' and the grid size is '2', then the base seeds for the four images should be [2,3,4,5]. This would let you pick one image from the grid and re-render it on its own (a blow-up render).

Is this possible?
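A sketch of the proposal: derive one seed per grid cell in row-major order and give each image its own random generator, so that cell i can later be regenerated on its own with `seed + i`. If all cells instead draw from one shared generator, as a batched sampler typically would, a single image depends on the whole batch and cannot be reproduced in isolation. `sample_tokens` is a hypothetical stand-in for the real sampler, and the 16384-entry codebook size is an assumption.

```python
import random
from typing import List

CODEBOOK_SIZE = 16384  # assumption: size of the VQGAN image-token vocabulary


def grid_seeds(seed: int, grid_size: int) -> List[int]:
    """One seed per cell of a grid_size x grid_size grid, in row-major order."""
    return [seed + i for i in range(grid_size * grid_size)]


def sample_tokens(seed: int, count: int) -> List[int]:
    """Hypothetical stand-in for the sampler: each image gets its own RNG."""
    rng = random.Random(seed)
    return [rng.randrange(CODEBOOK_SIZE) for _ in range(count)]
```

For example, with `seed=2` and `grid_size=2` the cells use seeds 2, 3, 4 and 5, and regenerating the bottom-right image alone means calling the sampler with seed 5.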

Execution stops with message "killed"

I have a ThinkPad X220 (no GPU :D) with an i5-2520M and 8 GB of RAM, running Arch Linux.

setup.sh completed with no issues.

My first two attempts used this command:
python image_from_text.py --text='alien life' --seed=7

Namespace(mega=False, torch=False, text='alien life', seed=7, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mini
tokenizing text
['Ġalien']
['Ġlife']
text tokens [0, 8925, 742, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "/home/peter/minidalle/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/home/peter/minidalle/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/home/peter/minidalle/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/home/peter/minidalle/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/home/peter/minidalle/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
    _, image_tokens = lax.scan(

[...]

raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/peter/minidalle/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/home/peter/minidalle/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/home/peter/minidalle/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/home/peter/minidalle/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/home/peter/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/peter/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/peter/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/peter/minidalle/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/home/peter/minidalle/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 214, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/home/peter/minidalle/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/home/peter/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/peter/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/peter/minidalle/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/home/peter/minidalle/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.

Other inputs make the process exit with a "Killed" message:

[peter@peter-arcox220 min-dalle]$ python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
Namespace(mega=True, torch=False, text='a comfy chair that looks like an avocado', seed=4, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
Killed
[peter@peter-arcox220 min-dalle]$ python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
Namespace(mega=True, torch=False, text='a comfy chair that looks like an avocado', seed=4, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
Killed
[peter@peter-arcox220 min-dalle]$ python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --seed=4
Namespace(mega=False, torch=True, text='a comfy chair that looks like an avocado', seed=4, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mini
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
loading torch encoder
encoding text tokens
loading torch decoder
sampling image tokens
image token 0 is 23
image token 1 is 8867
image token 2 is 15149
image token 3 is 10225
image token 4 is 6271
[...]
image token 74 is 4319
image token 75 is 14420
image token 76 is 9720
image token 77 is 7781
image token 78 is 8583
image token 79 is 5401
Killed
[peter@peter-arcox220 min-dalle]$ python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
Namespace(mega=True, torch=False, text='a comfy chair that looks like an avocado', seed=4, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
Killed

Incorrect dtypes error with the Mega model

Hey, I'm seeing the following error when passing the '--mega' option to use the mega model:

(min-dalle) ➜  min-dalle git:(main) ✗ python image_from_text.py --text="a comfy chair that looks like an avocado" --mega

/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lib/__init__.py:34: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Namespace(image_path='generated', image_token_count=256, mega=True, seed=0, text='a comfy chair that looks like an avocado', torch=False)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 218, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 770, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 754, in scanned
    c, y = fn(scope, c, *args)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 307, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
    return dynamic_update_slice_p.bind(operand, update, *start_indices)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
    out_aval, effects = primitive.abstract_eval(*avals, **params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 359, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
    lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

Changing image size?

I'd love to know where to change the code to allow image sizes larger than 256x256. Even better would be the ability to change the size from the Colab notebook.

msgpack.exceptions.ExtraData: unpack(b) received extra data.

The script crashes with msgpack.exceptions.ExtraData: unpack(b) received extra data.

  • Python 3.9.12
  • macOS 12.4
  • 2021 MacBook Pro (M1 Pro)
python image_from_text.py --text='alien life' --seed=7

/Users/samm/miniconda3/lib/python3.9/site-packages/jax/_src/lib/__init__.py:34: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Namespace(mega=False, torch=False, text='alien life', seed=7, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mini
tokenizing text
['Ġalien']
['Ġlife']
text tokens [0, 8925, 742, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
image tokens [6965, 6172, 1052, 14447, 6172, 12062, 15771, 2193, 10710, 4147, 1052, 6172, 2528, 14447, 5772, 8447, 6965, 14447, 14447, 14447, 11665, 6879, 15798, 9479, 910, 15303, 5605, 7542, 1052, 14447, 14447, 2528, 6965, 1052, 14447, 6078, 3386, 2519, 12838, 16017, 867, 8447, 11993, 12426, 11196, 14447, 14447, 2528, 6965, 14447, 14447, 7491, 16147, 13512, 8269, 271, 10397, 15945, 15945, 4903, 12892, 14447, 14447, 2528, 6965, 14447, 14447, 351, 358, 10362, 6001, 8612, 14037, 7864, 14246, 5201, 2810, 14447, 14447, 2528, 6965, 14447, 14447, 10549, 15618, 11792, 13401, 16223, 1464, 12861, 6992, 572, 601, 14447, 14447, 2528, 6965, 14447, 14447, 14447, 13183, 194, 14633, 1994, 10912, 2778, 5495, 12187, 2528, 14447, 14447, 2528, 6965, 14447, 14447, 14447, 2528, 14068, 4054, 5071, 1948, 5286, 7771, 12062, 12016, 14447, 14447, 2528, 6965, 14447, 14447, 14447, 7504, 15433, 7781, 4816, 12062, 663, 3812, 8447, 8173, 14447, 14447, 2528, 6965, 14447, 14447, 6078, 13401, 6790, 2813, 10121, 4301, 4811, 5984, 3851, 8493, 14447, 14447, 2528, 6965, 14447, 14447, 4465, 12509, 4238, 12290, 10543, 8222, 11348, 13909, 5919, 6965, 14447, 14447, 2528, 11591, 14447, 6172, 11665, 9501, 2810, 9570, 7781, 910, 10549, 4395, 10639, 16147, 8173, 14164, 2528, 11591, 14164, 11993, 11610, 15891, 6242, 1936, 14602, 4903, 3583, 11574, 7516, 12892, 8173, 14447, 2528, 11591, 7467, 5243, 13157, 2810, 6790, 16017, 7236, 4301, 11725, 10689, 11941, 12659, 8173, 1052, 2528, 6965, 6598, 4465, 4816, 2895, 11820, 3132, 15917, 1811, 4904, 6933, 6690, 4811, 7504, 2528, 11605, 7467, 4815, 351, 6948, 10228, 7771, 9479, 9213, 11196, 6628, 9897, 12480, 5885, 14247, 5772, 5772]
detokenizing image
Traceback (most recent call last):
  File "/Users/samm/git/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/samm/git/min-dalle/min_dalle/generate_image.py", line 74, in generate_image_from_text
    image = detokenize_torch(image_tokens)
  File "/Users/samm/git/min-dalle/min_dalle/min_dalle_torch.py", line 107, in detokenize_torch
    params = load_vqgan_torch_params(model_path)
  File "/Users/samm/git/min-dalle/min_dalle/load_params.py", line 11, in load_vqgan_torch_params
    params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
  File "/Users/samm/miniconda3/lib/python3.9/site-packages/flax/serialization.py", line 350, in msgpack_restore
    state_dict = msgpack.unpackb(
  File "msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb
msgpack.exceptions.ExtraData: unpack(b) received extra data.

WARNING:absl:No GPU/TPU found, falling back to CPU

$ TF_CPP_MIN_LOG_LEVEL=0 python3 image_from_text.py --text='food' --seed=7
...
2022-06-28 20:30:10.925294: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:174] XLA service 0x5688680 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2022-06-28 20:30:10.925364: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:182]   StreamExecutor device (0): Interpreter, <undefined>
2022-06-28 20:30:10.935443: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:176] TfrtCpuClient created.
2022-06-28 20:30:10.936503: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I have a GPU, so why isn't it being found?
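The warning comes from JAX (via absl), not from PyTorch. The command above has no `--torch` flag, so it runs the flax path. The warning means the installed `jaxlib` has no usable GPU or TPU backend. A common reason, assumed here rather than confirmed, is that `jax` was installed without a CUDA-enabled `jaxlib`. The hypothetical helper below uses `jax.default_backend()` and `torch.cuda.is_available()` to report what each framework can see. It skips any library that is not installed.

```python
import importlib.util
from typing import Dict, Optional


def visible_backends() -> Dict[str, Optional[object]]:
    """Report which accelerators JAX and PyTorch see; None if the library is missing."""
    report: Dict[str, Optional[object]] = {"jax": None, "torch_cuda": None}
    if importlib.util.find_spec("jax") is not None:
        import jax
        # "cpu" here means jaxlib found no usable GPU/TPU backend.
        report["jax"] = jax.default_backend()
    if importlib.util.find_spec("torch") is not None:
        import torch
        report["torch_cuda"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    print(visible_backends())
```

If JAX reports `"cpu"` while PyTorch reports `True`, the torch path (`--torch`) can use the GPU even though the flax path cannot.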

Use a virtual environment to install dependencies

A quick look at setup.sh shows that it installs the dependencies in requirements.txt directly into the system's Python installation. As far as I know, this is not considered good practice, so it might be a good idea to ask users for permission first.

Alternatively, installing everything into a venv would keep the system Python "clean".
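Here is a minimal sketch of the venv approach using only the standard library (`venv.create` plus the environment's own `pip`). The helper name `create_env` is hypothetical. The contents of setup.sh are not shown here, so this is not a drop-in replacement for it.

```python
import subprocess
import sys
import venv
from pathlib import Path
from typing import Optional


def create_env(env_dir: str, requirements: Optional[str] = None, with_pip: bool = True) -> Path:
    """Create a virtual environment and optionally install a requirements file into it."""
    env_path = Path(env_dir)
    venv.create(env_path, with_pip=with_pip)
    # venv places executables in Scripts\ on Windows and bin/ elsewhere.
    if sys.platform == "win32":
        python = env_path / "Scripts" / "python.exe"
    else:
        python = env_path / "bin" / "python"
    if requirements is not None:
        # Run pip with the venv's interpreter so packages land inside the venv,
        # not in the system Python.
        subprocess.run([str(python), "-m", "pip", "install", "-r", requirements], check=True)
    return python
```

For example, `create_env(".venv", "requirements.txt")` returns the path of the environment's interpreter, which can then be used to run `image_from_text.py`.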

Recent commit results in higher VRAM usage

The commit in question is ed91ab4. My system has 8GB of VRAM and is using Torch with ROCm support.

Here are my results with that commit and with the commit before it:

$ python image_from_text.py --text="alien life" --mega --torch --seed 100
Namespace(mega=True, torch=True, text='alien life', seed=100, image_path='generated', sample_token_count=256)
reading files from pretrained/dalle_bart_mega
initializing MinDalleTorch
loading encoder
loading decoder
Traceback (most recent call last):
  File "/home/user/min-dalle/image_from_text.py", line 68, in <module>
    generate_image(
  File "/home/user/min-dalle/image_from_text.py", line 48, in generate_image
    image_generator = MinDalleTorch(is_mega, sample_token_count)
  File "/home/user/min-dalle/min_dalle/min_dalle_torch.py", line 60, in __init__
    self.decoder = self.decoder.cuda()
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 706, in cuda
    return self._apply(lambda t: t.cuda(device))
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 595, in _apply
    module._apply(fn)
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 595, in _apply
    module._apply(fn)
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 595, in _apply
    module._apply(fn)
  [Previous line repeated 1 more time]
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 618, in _apply
    param_applied = fn(param)
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 706, in <lambda>
    return self._apply(lambda t: t.cuda(device))
RuntimeError: HIP out of memory. Tried to allocate 32.00 MiB (GPU 0; 7.98 GiB total capacity; 7.92 GiB already allocated; 58.00 MiB free; 7.93 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_HIP_ALLOC_CONF
$ git checkout 1ef9b0b
Previous HEAD position was ed91ab4 refactored to load models once and run multiple times
HEAD is now at 1ef9b0b added mega to colab
$ python image_from_text.py --text="alien life" --mega --torch --seed 100
Namespace(mega=True, torch=True, text='alien life', seed=100, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġalien']
['Ġlife']
text tokens [0, 8925, 742, 2]
loading torch encoder
encoding text tokens
loading torch decoder
sampling image tokens
detokenizing image
MIOpen(HIP): Warning [SQLiteBase] Missing system database file: gfx900_64.kdb Performance may degrade. Please follow instructions to install: https://github.com/ROCmSoftwarePlatform/MIOpen#installing-miopen-kernels-package
saving image to generated.png
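The two logs suggest why peak memory rose. At ed91ab4 the encoder and decoder are both loaded in `MinDalleTorch.__init__`, and the OOM is raised in `self.decoder.cuda()`. At 1ef9b0b the encoder appears to be loaded and used before the decoder is loaded. The toy model below is not min-dalle code, and the sizes in its usage are made up. It illustrates the trade-off: keeping every stage resident peaks at the sum of their sizes, while loading and freeing each stage in turn peaks at the largest one. In real PyTorch code, "freeing" would mean dropping the module reference (e.g. `del model`) and calling `torch.cuda.empty_cache()`.

```python
from contextlib import contextmanager
from typing import Iterator, List, Tuple


class GpuBudget:
    """Toy tracker of resident model memory in GiB (illustrative, not measured)."""

    def __init__(self, capacity_gib: float) -> None:
        self.capacity_gib = capacity_gib
        self.used_gib = 0.0
        self.peak_gib = 0.0

    def allocate(self, name: str, size_gib: float) -> None:
        needed = self.used_gib + size_gib
        if needed > self.capacity_gib:
            raise MemoryError(f"out of memory loading {name}: {needed:.2f} > {self.capacity_gib:.2f} GiB")
        self.used_gib = needed
        self.peak_gib = max(self.peak_gib, self.used_gib)

    def release(self, size_gib: float) -> None:
        self.used_gib -= size_gib


@contextmanager
def resident(budget: GpuBudget, name: str, size_gib: float) -> Iterator[None]:
    """Keep one model 'on the GPU' only while the with-block runs."""
    budget.allocate(name, size_gib)
    try:
        yield
    finally:
        budget.release(size_gib)


def load_once(budget: GpuBudget, stages: List[Tuple[str, float]]) -> float:
    """Load every stage up front and keep all of them resident; return the peak."""
    for name, size_gib in stages:
        budget.allocate(name, size_gib)
    return budget.peak_gib


def load_per_stage(budget: GpuBudget, stages: List[Tuple[str, float]]) -> float:
    """Load, run and free each stage before loading the next; return the peak."""
    for name, size_gib in stages:
        with resident(budget, name, size_gib):
            pass  # the stage's forward pass would run here
    return budget.peak_gib
```

Keeping models resident avoids reloading them for every image, which is the point of "load models once and run multiple times". The cost is a higher peak, which an 8 GB card may not have room for.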

thanks (it's 10x faster than JAX)!

I've been trying to get dalle-playground running performantly on M1, but there's a lot of work remaining to make the JAX model work via IREE/Vulkan.

so I tried out your pytorch model with a recent nightly of pytorch:

pip install --pre "torch>1.13.0.dev20220610" "torchvision>0.14.0.dev20220609" --extra-index-url https://download.pytorch.org/whl/nightly/cpu

…and it's 10x faster at dalle-mega than dalle-playground was on JAX/XLA!

using dalle-mega full:

wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1:latest

generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)!
the GPU looks less than half utilized; I haven't checked whether pytorch is the process that's using it.

these measurements are from M1 Max.
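One way to check whether this PyTorch build can use the M1 GPU at all is `torch.backends.mps.is_available()` (the MPS backend arrived around PyTorch 1.12). The helper below is hypothetical. This thread does not establish whether min-dalle itself runs on an `"mps"` device, so a `True` result only shows that PyTorch could use the GPU, not that this model was using it.

```python
import importlib.util


def torch_accelerator() -> str:
    """Return 'cuda', 'mps' or 'cpu' depending on what the installed PyTorch can use."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed, so there is nothing to accelerate
    import torch
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer builds, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```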

bonus: "crystal maiden and lina enjoying a pint together at a tavern"
[generated image]

Colab error output related to pretrained/vqgan/flax_model.msgpack

Just quickly sharing some error output from the Colab. Impressive project!


FileNotFoundError                         Traceback (most recent call last)
in <module>()
      1 from min_dalle.generate_image import generate_image_from_text
      2 
----> 3 image = generate_image_from_text("alien life", seed=7)
      4 display(image)

2 frames
/content/min-dalle/min_dalle/generate_image.py in generate_image_from_text(text, is_mega, is_torch, seed, image_token_count)
     72 
     73     if image_token_count == config['image_length']:
---> 74         image = detokenize_torch(image_tokens)
     75         return Image.fromarray(image)
     76     else:

/content/min-dalle/min_dalle/min_dalle_torch.py in detokenize_torch(image_tokens)
    105     print("detokenizing image")
    106     model_path = './pretrained/vqgan'
--> 107     params = load_vqgan_torch_params(model_path)
    108     detokenizer = VQGanDetokenizer()
    109     detokenizer.load_state_dict(params)

/content/min-dalle/min_dalle/load_params.py in load_vqgan_torch_params(path)
      8 
      9 def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
---> 10     with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
     11         params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
     12 

FileNotFoundError: [Errno 2] No such file or directory: './pretrained/vqgan/flax_model.msgpack'

./pretrained/vqgan/flax_model.msgpack does appear in the file tree.

"watch -n 1 nvidia-smi" shows I'm on a P100, if that helps.

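The path in the traceback, `'./pretrained/vqgan'`, is resolved against the process's current working directory, not against the repository. The traceback shows the code living under `/content/min-dalle/`. If the notebook's working directory is `/content` (an assumption), the file can exist in the file tree and still not be found. Changing into the repository first (e.g. `%cd /content/min-dalle` in a notebook cell) is one remedy. Another is to resolve the path explicitly. Here is a sketch of the second option, using a hypothetical helper that also reports where it looked:

```python
from pathlib import Path
from typing import Iterable, Optional


def find_model_file(relative_dir: str, filename: str,
                    extra_roots: Optional[Iterable[str]] = None) -> Path:
    """Look for relative_dir/filename under the CWD first, then under each extra root."""
    roots = [Path.cwd()] + [Path(r) for r in (extra_roots or [])]
    for root in roots:
        candidate = root / relative_dir / filename
        if candidate.is_file():
            return candidate.resolve()
    searched = ", ".join(str(r) for r in roots)
    raise FileNotFoundError(f"{relative_dir}/{filename} not found; searched: {searched}")
```

For example, `find_model_file("pretrained/vqgan", "flax_model.msgpack", ["/content/min-dalle"])` would locate the file regardless of the notebook's working directory.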
Incompatible with Python 3.10

If I install the dependencies in a Python 3.10 venv, I get the following error from flax when running image_from_text.py with no options:

  File "/home/matthew/Programming/Python/min-dalle/min_dalle/load_params.py", line 5, in <module>
    from flax import traverse_util, serialization
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/__init__.py", line 18, in <module>
    from . import core as core
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast as broadcast
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/core/axes_scan.py", line 19, in <module>
    import jax
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 29, in <module>
    from jax._src import lib
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 41, in <module>
    import scipy.signal as _signal
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/signal/__init__.py", line 302, in <module>
    from .filter_design import *
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/signal/filter_design.py", line 16, in <module>
    from scipy import special, optimize, fft as sp_fft
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/optimize/__init__.py", line 421, in <module>
    from ._shgo import shgo
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/optimize/_shgo.py", line 9, in <module>
    from scipy import spatial
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/__init__.py", line 107, in <module>
    from . import distance, transform
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/transform/__init__.py", line 19, in <module>
    from .rotation import Rotation, Slerp
ImportError: /home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/transform/rotation.cpython-310-x86_64-linux-gnu.so: undefined symbol: _PyGen_Send

If I downgrade the python version to 3.9 then it works.
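The `undefined symbol: _PyGen_Send` error shows that SciPy's compiled `rotation` extension expects a private CPython symbol that the 3.10 interpreter does not export. This suggests an older build of that package that predates Python 3.10 support. Upgrading to a SciPy release that ships Python 3.10 wheels may also work, but this thread only confirms the downgrade. When reporting this kind of error, exact versions help. Here is a small stdlib-only sketch (the helper name and default package list are assumptions):

```python
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional


def environment_report(packages: Iterable[str] = ("scipy", "jax", "jaxlib", "flax")) -> Dict[str, Optional[str]]:
    """Collect the Python version and the installed versions of the given packages."""
    report: Dict[str, Optional[str]] = {"python": ".".join(str(p) for p in sys.version_info[:3])}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report
```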

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32

Running python image_from_text.py --text='a comfy chair' --seed=7 shows the following error:

$ python image_from_text.py --text='a comfy chair' --seed=7

Namespace(mega=False, torch=False, text='a comfy chair', seed=7, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mini
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
text tokens [0, 58, 29872, 2408, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "/home/papul/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/home/papul/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 214, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 218, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 770, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 754, in scanned
    c, y = fn(scope, c, *args)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 307, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
    return dynamic_update_slice_p.bind(operand, update, *start_indices)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
    out_aval, effects = primitive.abstract_eval(*avals, **params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 359, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
    lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/papul/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/home/papul/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 214, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.
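`lax.dynamic_update_slice` requires the operand and the update to have the same dtype. Here the attention key cache and the newly computed keys disagree (one float16, one float32). One hedged fix is to cast the update to the cache's dtype before writing it. The alternative design choice is to allocate the cache in the dtype the layer actually computes in, which avoids losing precision in the update. `update_cache` is a hypothetical helper, not min-dalle code.

```python
import jax.numpy as jnp
from jax import lax


def update_cache(cache: jnp.ndarray, update: jnp.ndarray, position: int) -> jnp.ndarray:
    """Write `update` into `cache` starting at `position` along axis 0."""
    # dynamic_update_slice rejects mixed dtypes (e.g. a float16 cache and a
    # float32 update), so cast the update to the cache's dtype first.
    update = update.astype(cache.dtype)
    # One start index per dimension: `position` on axis 0, 0 on all others.
    start_indices = (position,) + (0,) * (cache.ndim - 1)
    return lax.dynamic_update_slice(cache, update, start_indices)
```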

Huggingface hub integration

Are there any plans to integrate the model download with the Hugging Face Hub?

Instead of keeping all models in a single repo, we could have two repos, one each for mini and mega.

I can contribute this feature. Please let me know your thoughts on this.

Thanks
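Here is a minimal sketch of the proposed two-repo layout, assuming `huggingface_hub.hf_hub_download` (which downloads one file from a Hub repo into a local cache and returns its path). The repo IDs and helper names are placeholders, since no repos are named in this thread. The file names to fetch are left to the caller for the same reason.

```python
from typing import List, Optional

# Hypothetical repository IDs: one repo for mini and one for mega, as proposed.
MODEL_REPOS = {
    "mini": "example-org/min-dalle-mini",
    "mega": "example-org/min-dalle-mega",
}


def repo_for(variant: str) -> str:
    """Map a model variant name to its (hypothetical) Hugging Face Hub repo ID."""
    try:
        return MODEL_REPOS[variant]
    except KeyError:
        raise ValueError(f"unknown variant {variant!r}; expected one of {sorted(MODEL_REPOS)}") from None


def download_variant(variant: str, filenames: List[str], cache_dir: Optional[str] = None) -> List[str]:
    """Download each file of a variant via huggingface_hub and return the local paths."""
    repo_id = repo_for(variant)  # validate before importing the optional dependency
    from huggingface_hub import hf_hub_download

    return [hf_hub_download(repo_id=repo_id, filename=name, cache_dir=cache_dir) for name in filenames]
```

Because files are cached locally, repeated runs would reuse earlier downloads instead of fetching the weights again.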
