attention-mask-control's Introduction

Code for paper: "Compositional Text-to-Image Synthesis with Attention Map Control of Diffusion Models"

[Project Page][Paper]

Requirements

A suitable conda environment named AMC can be created and activated with:

conda env create -f environment.yaml
conda activate AMC

Data Preparation

First, download the COCO dataset. We use COCO2014 in the paper. Then process the data with this script:

python coco_preprocess.py \
    --coco_image_path /YOUR/COCO/PATH/train2014 \
    --coco_caption_file /YOUR/COCO/PATH/annotations/captions_train2014.json \
    --coco_instance_file /YOUR/COCO/PATH/annotations/instances_train2014.json \
    --output_dir /YOUR/DATA/PATH
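
The two annotation files passed above share the standard COCO layout, in which captions and object instances are both keyed by image_id. As a rough illustration only (this is not the repository's coco_preprocess.py, and the function name and output record shape are assumptions), pairing each image's captions with its labelled boxes could look like this:

```python
from collections import defaultdict


def join_captions_and_boxes(captions, instances):
    """Group COCO captions and labelled boxes by image_id (hypothetical helper)."""
    id_to_name = {c["id"]: c["name"] for c in instances["categories"]}
    merged = defaultdict(lambda: {"captions": [], "objects": []})
    for ann in captions["annotations"]:
        merged[ann["image_id"]]["captions"].append(ann["caption"])
    for ann in instances["annotations"]:
        # COCO boxes are stored as [x, y, width, height] in pixels.
        x, y, w, h = ann["bbox"]
        merged[ann["image_id"]]["objects"].append(
            {"label": id_to_name[ann["category_id"]], "bbox": [x, y, w, h]}
        )
    return dict(merged)


# Tiny in-memory stand-ins for captions_train2014.json and instances_train2014.json.
captions = {"annotations": [{"image_id": 1, "caption": "a dog on a couch"}]}
instances = {
    "categories": [{"id": 18, "name": "dog"}],
    "annotations": [
        {"image_id": 1, "category_id": 18, "bbox": [10.0, 20.0, 30.0, 40.0]}
    ],
}
records = join_captions_and_boxes(captions, instances)
print(records[1]["captions"], records[1]["objects"][0]["label"])
```

With real data, the two dictionaries would come from json.load on the annotation files.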

Training

Before training, update the following settings in train_boxnet.sh:

  • ROOT_DIR: directory where all results are saved.
  • webdataset_base_urls: path pattern of the preprocessed shards, /YOUR/DATA/PATH/{xxx-xxx}.tar
  • model_path: path to the Stable Diffusion v1-5 checkpoint.
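
The {xxx-xxx} part of webdataset_base_urls is a placeholder for the range of shard files. WebDataset URLs commonly use a {start..end} brace range for this. Assuming your shards follow that convention (the shard names below are made up, and expand_shards is a hypothetical helper), this sketch lists the files a pattern refers to, so you can check it before training:

```python
import re


def expand_shards(pattern):
    """Expand a single '{start..end}' numeric range, preserving zero padding."""
    match = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if match is None:
        return [pattern]  # no range present: the pattern is a single file
    start, end = match.group(1), match.group(2)
    width = len(start)
    return [
        pattern[: match.start()] + str(i).zfill(width) + pattern[match.end():]
        for i in range(int(start), int(end) + 1)
    ]


shards = expand_shards("/YOUR/DATA/PATH/{000..002}.tar")
print(shards)
# ['/YOUR/DATA/PATH/000.tar', '/YOUR/DATA/PATH/001.tar', '/YOUR/DATA/PATH/002.tar']
```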

You can train the BoxNet through this script:

sh train_boxnet.sh $NODE_NUM $CURRENT_NODE_RANK $GPUS_PER_NODE
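
The three arguments describe the multi-node layout. The sketch below shows the usual distributed-training arithmetic behind such arguments, assuming the script follows the common convention that a process's global rank is node_rank * gpus_per_node + local_rank. The README does not confirm this, and world_layout is a hypothetical helper.

```python
def world_layout(node_num, node_rank, gpus_per_node):
    """Return the total process count and the global ranks hosted on one node."""
    if not 0 <= node_rank < node_num:
        raise ValueError("node_rank must be in [0, node_num)")
    world_size = node_num * gpus_per_node
    global_ranks = [
        node_rank * gpus_per_node + local_rank
        for local_rank in range(gpus_per_node)
    ]
    return world_size, global_ranks


# Second node (rank 1) of a 2-node job with 8 GPUs per node.
world_size, ranks = world_layout(node_num=2, node_rank=1, gpus_per_node=8)
print(world_size, ranks)
# 16 [8, 9, 10, 11, 12, 13, 14, 15]
```

For a single machine with one GPU, the call would be sh train_boxnet.sh 1 0 1.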

Text-to-Image Synthesis

With a trained BoxNet, you can run text-to-image synthesis with:

python test_pipeline_onestage.py \
    --stable_model_path /stable-diffusion-v1-5/checkpoint \
    --boxnet_model_path /TRAINED/BOXNET/CKPT \
    --output_dir /YOUR/SAVE/DIR

All test prompts are stored in the file test_prompts.json.

TODOs

  • Release data preparation code
  • Release inference code
  • Release training code
  • Release demo
  • Release checkpoint

Acknowledgements

This implementation builds on the diffusers library, the Fengshenbang-LM codebase, and the DETR codebase.

attention-mask-control's People

Contributors

1073521013, wrch1994

attention-mask-control's Issues

AttributeError: 'SetCriterion' object has no attribute 'loss_masks'. Did you mean: 'loss_labels'?

Thanks for your work!

I tried to run train_boxnet.sh and hit an error about the missing attribute 'loss_masks'.
Specifically, in boxnet_models/boxnet.py, line 199, in get_loss:

def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
    loss_map = {
        'labels': self.loss_labels,
        'cardinality': self.loss_cardinality,
        'boxes': self.loss_boxes,
        'masks': self.loss_masks
    }
    assert loss in loss_map, f'do you really want to compute {loss} loss?'
    return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

Looking forward to your reply. Thank you.
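
Because loss_map is a dict literal rebuilt on every call, self.loss_masks is looked up even when a different loss is requested, so the AttributeError appears for any loss name. One possible workaround, assuming the masks loss is not needed for BoxNet training, is to register only the loss methods the class actually defines. This is a minimal sketch with a hypothetical stand-in class, not the repository's SetCriterion:

```python
class SetCriterionSketch:
    """Hypothetical stand-in for the criterion in boxnet_models/boxnet.py."""

    def loss_labels(self, outputs, targets, indices, num_boxes, **kwargs):
        return {"loss_ce": 0.0}  # placeholder value for illustration

    def loss_boxes(self, outputs, targets, indices, num_boxes, **kwargs):
        return {"loss_bbox": 0.0}  # placeholder value for illustration

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        # Keep only the losses this class implements, so a missing
        # 'loss_masks' method no longer fails while building the map.
        names = ("labels", "cardinality", "boxes", "masks")
        loss_map = {
            name: getattr(self, f"loss_{name}")
            for name in names
            if hasattr(self, f"loss_{name}")
        }
        assert loss in loss_map, f"{loss} loss is not implemented"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)


criterion = SetCriterionSketch()
result = criterion.get_loss("boxes", {}, [], [], 1)
print(result)
# {'loss_bbox': 0.0}
```

Requesting "masks" with this sketch still fails, but with a clear assertion message and only when that loss is actually asked for.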

RuntimeError: CUDA error: an illegal instruction was encountered

Hello, I'm very interested in your work. I trained BoxNet successfully with Python 3.8, torch 2.0.1, and CUDA 11.7. I then wanted to fine-tune the UNet, so I set '--train_unet' True and trained on the same devices, but I got RuntimeError: CUDA error: an illegal instruction was encountered. How can I train the UNet? Thank you.

Traceback (most recent call last):
  File "train_boxnet.py", line 619, in <module>
    trainer.fit(model, datamoule, ckpt_path=args.load_ckpt_path)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 63, in _call_and_handle_interrupt
    trainer._teardown()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _teardown
    self.strategy.teardown()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 490, in teardown
    super().teardown()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/parallel.py", line 125, in teardown
    super().teardown()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 492, in teardown
    _optimizers_to_device(self.optimizers, torch.device("cpu"))
  File "/opt/conda/lib/python3.8/site-packages/lightning_fabric/utilities/optimizer.py", line 28, in _optimizers_to_device
    _optimizer_to_device(opt, device)
  File "/opt/conda/lib/python3.8/site-packages/lightning_fabric/utilities/optimizer.py", line 34, in _optimizer_to_device
    optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device)
  File "/opt/conda/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py", line 59, in apply_to_collection
    v = apply_to_collection(
  File "/opt/conda/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py", line 51, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lightning_fabric/utilities/apply_func.py", line 101, in move_data_to_device
    return apply_to_collection(batch, dtype=_TransferableDataType, function=batch_to)
  File "/opt/conda/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py", line 51, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lightning_fabric/utilities/apply_func.py", line 95, in batch_to
    data_output = data.to(device, **kwargs)
RuntimeError: CUDA error: an illegal instruction was encountered
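
CUDA kernels are launched asynchronously, so the error can surface in a later call (here, the teardown that moves optimizer state to the CPU) rather than in the operation that actually failed. Setting the CUDA_LAUNCH_BLOCKING=1 environment variable makes launches synchronous, which usually moves the traceback closer to the faulty operation. It is a diagnostic aid, not a fix. A minimal sketch for preparing such an environment before relaunching training (blocking_cuda_env is a hypothetical helper, and the launch arguments in the comment are examples):

```python
import os


def blocking_cuda_env(base_env=None):
    """Return a copy of the environment with synchronous CUDA launches enabled."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


env = blocking_cuda_env({"PATH": "/usr/bin"})
print(env["CUDA_LAUNCH_BLOCKING"])
# 1

# Example relaunch (not executed here), single node with one GPU:
# import subprocess
# subprocess.run(["sh", "train_boxnet.sh", "1", "0", "1"],
#                env=blocking_cuda_env(), check=True)
```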
