
shilin-lu / tf-icon

769 stars · 35 watchers · 100 forks · 77.05 MB

[ICCV 2023] "TF-ICON: Diffusion-Based Training-Free Cross-Domain Image Composition" (Official Implementation)

Home Page: https://shilin-lu.github.io/tf-icon.github.io/

License: MIT License

Python 96.59% Jupyter Notebook 3.26% HTML 0.15%
image-composition image-inversion generative-ai stable-diffusion text-to-image diffusion-model

tf-icon's People

Contributors

eltociear, shilin-lu


tf-icon's Issues

Questions Regarding the Implementation

Hi, thanks for releasing the code. It's very interesting work.
I am new to this topic. Therefore, I have several questions regarding the code that I hope you can help address.

  1. In the attention.py file, it can be seen that the forward method of the CrossAttention class never uses the parameters encode=False, controller_for_inject=None, inject=False, layernum=None, main_height=None, main_width=None. I think this does not break the code, since only the code path starting from this line is executed (the arguments encode_uncon == False and decode_uncon == False are the defaults in the apply_model method), but it leads me to wonder why these arguments are there. Guessing from their names, they may relate to the experiments, for example Figs. 3 and 17, that advocate using the DPM solver (with the exceptional prompt) to obtain the noise-corrupted latent code of the real images (background and foreground). If this is the case, would you mind releasing the code so that I can experiment with different null-text inversion methods, i.e., the methods you compare against in Fig. 3? I initially wanted to test this on my own, so I followed your code and used DPMSolverSampler to get z_ref_enc, but I could not find a way to "denoise" it back to the original image by setting DPMencode=False in DPMSolverSampler.sample, which may be expected.

  2. I also have questions about the code design. I saw that in the ldm.models.diffusion.dpm_solver folder you made significant modifications. For example, the encode method of the DPMSolverSampler in the original implementation was removed and folded into the sample method. Also, chunks of code were extracted from different places to create "helper functions" like low_order_sample. I am wondering whether these changes were made because of the manipulation of the cross-attention maps, or whether there are other reasons I am not aware of.

Thanks,
YD
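
A self-contained toy sketch of the round-trip idea behind question 1: with a deterministic DDIM/DPM-style update rule and a fixed noise predictor, running the same discretization forward (latent to noise) and then backward approximately recovers the input. This uses a stand-in noise predictor and an illustrative schedule, not TF-ICON's DPMSolverSampler, model, or exceptional-prompt embedding.

    import torch

    torch.manual_seed(0)

    T = 50                                           # number of discretization steps
    alphas = torch.linspace(0.999, 0.01, T).sqrt()   # signal coefficients (illustrative schedule only)
    sigmas = (1.0 - alphas ** 2).sqrt()              # matching noise coefficients

    def eps_model(x, i):
        # Stand-in for a noise-prediction network; the real pipeline would call the U-Net here.
        return 0.05 * x

    def ddim_step(x, i, j):
        # Deterministic DDIM-style update from step i to step j (used in both directions).
        eps = eps_model(x, i)
        x0_pred = (x - sigmas[i] * eps) / alphas[i]
        return alphas[j] * x0_pred + sigmas[j] * eps

    x0 = torch.randn(4)
    x = x0.clone()
    for i in range(T - 1):                # "encode": x0 -> xT
        x = ddim_step(x, i, i + 1)
    for i in range(T - 1, 0, -1):         # "decode": xT -> back towards x0
        x = ddim_step(x, i, i - 1)

    print("max round-trip error:", (x - x0).abs().max().item())  # small but nonzero: inversion is approximate

In the real pipeline the fidelity of this round trip also depends on the conditioning used during encoding, which is what the exceptional prompt is meant to address.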

difference between cross and same domain

I want to know what the difference is. Is it just a difference in scale? What impact does it have, intuitively? I can't seem to find this discussed in the paper.

diffusers version

Great job! Do you have any plans to integrate this into Diffusers?

Query regarding setting batch size

Hi! I couldn't find in the codebase whether setting n_samples = 8 would automatically walk through a folder with 8 background and 8 foreground images while using the same prompt. The init_image and ref_image seem to be resized using the repeat() function. Would it be possible to show an example using the n_samples argument?

Also, thanks for open-sourcing this wonderful work!
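
For context on the repeat() behaviour mentioned above, here is a hedged illustration (not the repo's exact code) of what repeat()-style batching does: one background/foreground pair is tiled n_samples times along the batch dimension, so a single prompt and image pair yields n_samples stochastic outputs rather than walking through n_samples different files.

    import torch

    n_samples = 8
    init_latent = torch.randn(1, 4, 64, 64)              # stand-in for one encoded background image
    init_batch = init_latent.repeat(n_samples, 1, 1, 1)  # tile the same latent along the batch axis
    print(init_batch.shape)                              # torch.Size([8, 4, 64, 64])

If that reading is right, composing 8 different pairs would need an outer loop over the files rather than a larger n_samples.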

About Metric Calculations

Thank you for your excellent work. Will the code for calculating the metrics in your article be published?

Can a 3090 GPU run this code?

I tried to run this:

python scripts/main_tf_icon.py  --ckpt ./ckpt/v2-1_512-ema-pruned.ckpt      \
                                --root ./inputs/same_domain      \
                                --cross_domain False               \
                                --dpm_steps 20                    \
                                --dpm_order 2                     \
                                --scale 5                         \
                                --tau_a 0.4                       \
                                --tau_b 0.8                       \
                                --outdir ./outputs                \
                                --gpu cuda:0                      \
                                --seed 3407 

And got the error:

RuntimeError: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.70 GiB total capacity; 19.30 GiB already allocated; 50.56 MiB free; 22.12 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
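
Two hedged suggestions for this error (not guaranteed fixes): install xformers so the memory-efficient attention path is used (a later log on this page shows attention of type 'vanilla-xformers' being built when it is available), and set PYTORCH_CUDA_ALLOC_CONF as the error message itself suggests, e.g.:

    import os

    # Must be set before any CUDA allocation, e.g. at the very top of scripts/main_tf_icon.py
    # or as an environment variable on the command line. The value 128 is a guess to tune.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"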

Tunnel connection failed:403 Forbidden

Hello, thank you for your work. I now want to replicate your work on my own dataset. When I run the program, I get the following error:
(base) root@c5a0a7227fb0:/home/401229/TF-ICON-main# python scripts/main_tf_icon.py --ckpt ./ckpt/v2-1_512-ema-pruned.ckpt --root ./inputs/cross_domain --domain 'cross' --dpm_steps 20 --dpm_order 2 --scale 5 --tau_a 0.4 --tau_b 0.8 --outdir ./outputs --gpu cuda:0 --seed 3407
Loading model from ./ckpt/v2-1_512-ema-pruned.ckpt
Global Step: 220000
No module 'xformers'. Proceeding without it.
/root/anaconda3/lib/python3.11/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.
  rank_zero_deprecation(
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin (Caused by ProxyError('Cannot connect to proxy.', OSError('Tunnel connection failed: 403 Forbidden')))' thrown while requesting HEAD https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin
WARNING:huggingface_hub.utils._http: (the same HTTPSConnectionPool error repeated)
Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.11/site-packages/urllib3/connectionpool.py", line 711, in urlopen
    self._prepare_proxy(conn)
  File "/root/anaconda3/lib/python3.11/site-packages/urllib3/connectionpool.py", line 1007, in _prepare_proxy
    conn.connect()
  File "/root/anaconda3/lib/python3.11/site-packages/urllib3/connection.py", line 374, in connect
    self._tunnel()
  File "/root/anaconda3/lib/python3.11/http/client.py", line 926, in _tunnel
    raise OSError(f"Tunnel connection failed: {code} {message.strip()}")
OSError: Tunnel connection failed: 403 Forbidden
Is this because I don't have a VPN? Besides relying on the v2-1_512-ema-pruned.ckpt model, will the program also download open_clip_pytorch_model.bin? Please help me, thanks!
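
On the second question: yes, the log above shows that, besides the checkpoint, the code fetches open_clip_pytorch_model.bin from laion/CLIP-ViT-H-14-laion2B-s32B-b79K on Hugging Face (the OpenCLIP text encoder used by SD 2.1), and it is this request that the proxy rejects. One possible workaround, sketched below under the assumption that some machine or proxy configuration can reach huggingface.co, is to populate the local Hugging Face cache once and then run offline:

    from huggingface_hub import hf_hub_download

    # Download the OpenCLIP weights into the local Hugging Face cache.
    path = hf_hub_download(
        repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        filename="open_clip_pytorch_model.bin",
    )
    print("cached at:", path)

After the file is cached, setting the environment variable HF_HUB_OFFLINE=1 makes huggingface_hub skip the failing HEAD request and use the cached copy.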

How much VRAM is needed for this?

This looks great!

But I ran into an out-of-memory error while running the code. The device I'm using is an RTX 3090 with 24 GB. Could you share some information on how much memory is needed to run this code successfully? Thanks!

The error message is as follows:

/home/wenhuaszhgc/miniconda3/envs/xqqpy38/bin/python3.8 /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/main_tf_icon.py --ckpt /home/wenhuaszhgc/users/xqq/TF-ICON/ckpt/v2-1_512-ema-pruned.ckpt --root /home/wenhuaszhgc/users/xqq/TF-ICON/inputs/cross_domain --domain cross --dpm_steps 20 --dpm_order 2 --scale 5 --tau_a 0.4 --tau_b 0.8 --outdir /home/wenhuaszhgc/users/xqq/TF-ICON/outputs --gpu cuda:2 --seed 3407 
Loading model from /home/wenhuaszhgc/users/xqq/TF-ICON/ckpt/v2-1_512-ema-pruned.ckpt
Global Step: 220000
/home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.
  rank_zero_deprecation(
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Global seed set to 3407
1
loaded input image of size (512, 512) from /home/wenhuaszhgc/users/xqq/TF-ICON/inputs/cross_domain/a pencil drawing of an eiffel tower in the distance, black and white painting/bg48.png
['a pencil drawing of an eiffel tower in the distance, black and white painting']
Traceback (most recent call last) [abridged; repeated torch nn.Module._call_impl frames omitted]:
  scripts/main_tf_icon.py:565 in <module>: main()
  scripts/main_tf_icon.py:519 in main: samples, _ = sampler.sample(...)
  torch/autograd/grad_mode.py:27 in decorate_context
  ldm/models/diffusion/dpm_solver/sampler.py:184 in sample: orig = dpm_solver_decode.sample_one_step(...)
  scripts/dpm_solver_pytorch.py:1289 in sample_one_step
  scripts/dpm_solver_pytorch.py:471 in model_fn
  scripts/dpm_solver_pytorch.py:459 in data_prediction_fn
  scripts/dpm_solver_pytorch.py:453 in noise_prediction_fn
  scripts/dpm_solver_pytorch.py:347 in model_fn
  scripts/dpm_solver_pytorch.py:295 in noise_pred_fn
  ldm/models/diffusion/dpm_solver/sampler.py:118 in <lambda>
  ldm/models/diffusion/ddpm.py:859 in apply_model
  ldm/models/diffusion/ddpm.py:1330 in forward
  ldm/modules/diffusionmodules/openaimodel.py:796 in forward
  ldm/modules/diffusionmodules/openaimodel.py:89 in forward
  ldm/modules/attention.py:366 in forward
  ldm/modules/attention.py:280 in forward
  ldm/modules/diffusionmodules/util.py:117 in checkpoint
  ldm/modules/diffusionmodules/util.py:132 in forward
  ldm/modules/attention.py:301 in _forward
  ptp_scripts/ptp_utils.py:375 in forward: sim = sim.softmax(dim=-1)
RuntimeError: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 2; 23.70 GiB total capacity; 18.52 GiB already allocated; 434.69 MiB free; 19.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Process finished with exit code 1

Gradio APP build invitation

Hi, thank you very much for your open source work! I am an employee of the Shanghai Artificial Intelligence Lab. I would like to invite you to build the Gradio Demo of TF-ICON on our OpenXLab. We can provide you with GPU inference and operation support for free. Looking forward to your reply!

Question about the restriction for the mask

Hello, thank you for your great work. I have one question: what is the restriction on the mask? I often get "RuntimeError: The size of tensor a must match the size of tensor b at non-singleton dimension 2" when compositing in pixel space. Thank you again!
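
One common cause of this kind of shape mismatch (a hedged guess, not a confirmed diagnosis) is that the background, foreground, and mask images have different sizes, or sizes that are not multiples of the VAE's downsampling factor of 8. A minimal preprocessing sketch, with placeholder file names, that normalizes everything to the 512x512 resolution used by the provided examples:

    from PIL import Image

    SIZE = 512  # assumption: the 512x512 resolution used by the provided example inputs

    bg   = Image.open("bg.png").convert("RGB").resize((SIZE, SIZE), Image.BICUBIC)
    fg   = Image.open("fg.png").convert("RGB").resize((SIZE, SIZE), Image.BICUBIC)
    mask = Image.open("fg_mask.png").convert("L").resize((SIZE, SIZE), Image.NEAREST)  # NEAREST keeps the mask binary

    assert bg.size == fg.size == mask.size  # all inputs must agree before compositing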

Is it possible to composite multiple images onto one background image?

First of all, thank you for providing such great code.
I would like to use this model to composite various chicken images onto one background image.
For example, using chicken images available on GitHub, I would like to composite chicken images with slight variations onto multiple plates in a background image.
In that case, I think I will need chicken masks, original chicken images, a background image, and a mask image for the background.
I would greatly appreciate it if you could provide an answer!

Some questions about the same domain.

Hi, this is great work! But I ran into some problems. Is there any requirement on the image domain in same-domain mode? Can it only be photographs? Or can it only be one of the four image domains in the paper: photorealism, oil painting, sketchy painting, and cartoon?

Figure 1 is the image I want to synthesize and Figure 2 is the result of the code's synthesis. If I want to synthesize images in a children's picture-book style, what should I do? By the way, the images in these styles were generated using Stable Diffusion and LoRA. Looking forward to your reply!

[Attached images: cp_bg_fg; 00000_a cartoon animation of a girl running in the street]

Weird Error

When running main.py as instructed, a quite weird error occurs at

inv_emb = model.get_learned_conditioning(prompts, inv)

The error reads:

TypeError: get_learned_conditioning() takes 2 positional arguments but 3 were given

The function get_learned_conditioning in ddpm obviously takes two arguments; here prompts is ['a pencil drawing of an eiffel tower in the distance, black and white painting'] and inv = True, so I can't see anything wrong.
Might it be caused by the version of Stable Diffusion (the function might differ between models)? I used the same pretrained Stable Diffusion checkpoint that the link refers to.
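
One way to debug this (a hedged suggestion, not the author's fix) is to check which get_learned_conditioning is being resolved at runtime; if its source file points into an installed ldm/stable-diffusion package rather than the ldm/ directory shipped with TF-ICON, the unmodified two-argument version would explain the TypeError. Here model is the LatentDiffusion instance loaded by the script:

    import inspect

    fn = type(model).get_learned_conditioning
    print(inspect.getsourcefile(fn))  # should point into the TF-ICON repo's ldm/ directory
    print(inspect.signature(fn))      # the modified version should accept more than just (self, c)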

Foreground mask

I notice the code needs a foreground mask, but this is never mentioned in the paper. Would you please tell me the reason?

xl

Can this work using SDXL?

About custom masks

Excellent work: "TF-ICON: Diffusion-Based Training-Free Cross-Domain Image Composition",
TF-ICON has made significant advancements in terms of training-free image composition, which is truly impressive.

Similar to the examples provided on GitHub, such as fg_mask.png and mask_bg_fg.jpg, I am keen to understand a convenient method to generate these masks effortlessly for custom images. Could you please share insights or suggestions on achieving this?
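
A hedged sketch of one convenient way to produce such masks for custom images, using the off-the-shelf rembg background remover (not part of TF-ICON; file names are placeholders):

    import numpy as np
    from PIL import Image
    from rembg import remove  # pip install rembg

    fg = Image.open("my_foreground.png").convert("RGB")
    cut = remove(fg)                                    # RGBA output; the alpha channel separates the subject
    alpha = np.array(cut)[:, :, 3]
    mask = Image.fromarray(((alpha > 127) * 255).astype(np.uint8))  # binary mask, white = foreground
    mask.save("fg_mask.png")

Any other segmentation tool or manual annotation works as well; judging from the provided examples, all that is needed is a binary mask aligned with the foreground image.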

