yk7333 / d3po

[CVPR 2024] Code for the paper "Using Human Feedback to Fine-tune Diffusion Models without Any Reward Model"

Home Page: https://arxiv.org/abs/2311.13231

License: MIT License

Python 100.00%
diffusion-models human-feedback reinforcement-learning

d3po's Introduction

  • 👋 Hi, I’m Kai Yang (杨恺)
  • 👀 I’m interested in AI
  • 🌱 I’m currently learning Reinforcement Learning
  • 💞️ I’m a master's student currently studying at THU SIGS
  • 📫 You can contact me at [email protected]

d3po's Issues

DPO with existing images

Hello! I'm interested in your work, and I want to build on it.
I already have a set of generated images, divided into preferred and unpreferred ones. How can I train the diffusion model directly on these existing images within your framework? Thanks!
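One point relevant to this question: a DPO-style objective for diffusion compares log-probabilities of the denoising steps under the model being fine-tuned and under a frozen reference model, and D3PO applies the preference label at each denoising step. Those per-step log-probabilities come from sampling trajectories, not from the final images alone, so images generated outside the framework may need their trajectories regenerated or approximated (an assumption about the pipeline; check how the sampling script saves its data). As a minimal sketch, assuming you already have per-step `(logp, ref_logp)` pairs for the preferred and unpreferred image, the generic DPO loss is shown below. `beta = 0.1` matches the `beta` in the training config quoted in a later issue on this page; the function names and tuple layout are hypothetical, and this is not necessarily what d3po's train.py computes.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def dpo_step_loss(logp_w: float, ref_logp_w: float,
                  logp_l: float, ref_logp_l: float,
                  beta: float = 0.1) -> float:
    """Generic DPO loss for one preferred (w) / unpreferred (l) step.

    loss = -log sigmoid(beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l)))
    """
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    # -log(sigmoid(m)) == softplus(-m); this form avoids overflow for large |m|
    return softplus(-margin)


def trajectory_dpo_loss(steps_w, steps_l, beta: float = 0.1) -> float:
    """Average the per-step DPO loss over a denoising trajectory.

    steps_w / steps_l: hypothetical lists of (logp, ref_logp) tuples, one per
    denoising step, for the preferred and the unpreferred image.
    """
    if not steps_w or len(steps_w) != len(steps_l):
        raise ValueError("trajectories must be non-empty and of equal length")
    losses = [dpo_step_loss(lw, rw, ll, rl, beta)
              for (lw, rw), (ll, rl) in zip(steps_w, steps_l)]
    return sum(losses) / len(losses)
```

When the fine-tuned model still equals the reference, every margin is zero and the loss is log 2 (about 0.693); it falls below that as the model raises the relative likelihood of the preferred steps.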

Releasing human preference data

Hey,
Thanks for your great work and for releasing your code publicly. I am very interested in the human evaluation data that you show in Fig 7 of the paper. Is it possible for you to release the exact prompts, image-pairs and raw human preferences that you collected for plotting Fig 7? This would be super helpful, thanks!

ValueError: Attempting to unscale FP16 gradients

Fine-tuning the U-Net with LoRA disabled and fp16 AMP triggers a “ValueError: Attempting to unscale FP16 gradients” error.
This happens because the U-Net's parameters, and therefore its gradients, are fp16, and the gradient scaler refuses to unscale fp16 gradients.
It can be fixed in train_with_rm.py by keeping the U-Net in fp32 when LoRA is disabled:
pipeline.vae.to(accelerator.device, dtype=inference_dtype)
pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
# New: without LoRA the U-Net weights themselves are trained, so keep them in fp32
unet_dtype = inference_dtype if config.use_lora else torch.float32
pipeline.unet.to(accelerator.device, dtype=unet_dtype)
🤔 Is there a better solution?
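A common alternative, used in some diffusers example training scripts, is to move every module to the fp16 inference dtype and then upcast only the parameters that will actually be trained, so frozen weights stay in fp16 while the gradient scaler only ever sees fp32 gradients. The helper below is a hypothetical sketch of that pattern in plain PyTorch (the name `upcast_trainable_params` is not part of d3po or diffusers), and it assumes `requires_grad` is already set correctly on the parameters you intend to train.

```python
import torch
from torch import nn


def upcast_trainable_params(model: nn.Module, dtype: torch.dtype = torch.float32) -> int:
    """Cast only the parameters that require gradients to `dtype`.

    Frozen parameters keep their current (e.g. fp16) dtype to save memory,
    while trainable ones become fp32 so GradScaler can unscale their grads.
    Returns the number of parameters that were cast.
    """
    n_cast = 0
    for param in model.parameters():
        if param.requires_grad and param.dtype != dtype:
            # Swap the underlying tensor; the Parameter object stays the same
            param.data = param.data.to(dtype)
            n_cast += 1
    return n_cast


# Toy example: an fp16 layer whose bias is frozen
layer = nn.Linear(4, 4).half()
layer.bias.requires_grad_(False)
upcast_trainable_params(layer)  # only the weight is cast to fp32
```

In the setting above this would mean calling the helper on `pipeline.unet` after moving it to `inference_dtype` (and, with LoRA, after the adapters are attached), before the optimizer is created. Because it keys on `requires_grad`, the same call covers both the full fine-tuning and the LoRA case.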

Online or Offline?

Is the algorithm offline or online? In other words, is the training dataset fixed, or is it regenerated before each epoch?

Reproducing Aesthetic Quality results from paper

I'm trying to reproduce the aesthetic quality results from the paper with the simple animal prompts on a 4-GPU node.
The sampling and training parameters have been kept the same, apart from a learning rate of 3e-5 (taken from the paper's appendix).

The results seem to show the images disintegrating:
[Two screenshots from 2023-12-14 showing the generated samples]

A few questions:

  1. Are any other parameter adjustments needed?
  2. Is D3PO generally very sensitive to hyperparameters, and is that sensitivity data-specific?
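Since this run uses a 4-GPU node, one thing worth checking is how the number of processes changes the per-epoch sample count and the effective batch size. The sketch below assumes d3po keeps DDPO-style bookkeeping. That is an assumption, but with one process and the sampling and training values from the config in the `assert num_timesteps` issue on this page, it reproduces the numbers printed in that training log (200 samples per epoch, total train batch size 1, 200 gradient updates per inner epoch). The function and field names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EpochShape:
    samples_per_epoch: int
    total_train_batch_size: int
    updates_per_inner_epoch: int


def epoch_shape(sample_batch_size: int, num_batches_per_epoch: int,
                train_batch_size: int, grad_accum_steps: int,
                num_processes: int) -> EpochShape:
    # Samples collected per epoch across all processes
    samples = sample_batch_size * num_processes * num_batches_per_epoch
    # Samples contributing to one optimizer step (data parallel + accumulation)
    total_batch = train_batch_size * num_processes * grad_accum_steps
    return EpochShape(samples, total_batch, samples // total_batch)


# Values from the quoted config: sample.batch_size=2, num_batches_per_epoch=100,
# train.batch_size=1, gradient_accumulation_steps=1
print(epoch_shape(2, 100, 1, 1, num_processes=1))
# EpochShape(samples_per_epoch=200, total_train_batch_size=1, updates_per_inner_epoch=200)
print(epoch_shape(2, 100, 1, 1, num_processes=4))
# EpochShape(samples_per_epoch=800, total_train_batch_size=4, updates_per_inner_epoch=200)
```

Under this assumption, going from 1 to 4 GPUs with unchanged per-device settings quadruples both the samples per epoch and the effective batch size per update while the number of updates stays the same, so a learning rate chosen for one setup may not transfer directly to the other.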

Pretrained weight release

First, thank you for the nice work and for releasing the training code.
To make further experiments easier, do you plan to release pretrained weights?

Combining with DreamBooth

I wonder whether this method could be used to improve the results of DreamBooth fine-tuning, for example face fidelity, or to make the plastic-looking skin of personalized characters in DreamBooth results look more realistic.

assert num_timesteps == config.sample.num_steps AssertionError

Thank you for your excellent work!
I attempted to use RLHF to reduce human-body defects in anything-v4.5, using the dataset and prompts you released publicly. However, an error occurred during training, the same one as in the title.
The details of the training run, conducted on RunPod, are provided below.
I have also uploaded base.py and train.py to https://huggingface.co/datasets/sdtana/anything-v4.5_dpo.
As a non-specialist, I apologize for reaching out to you directly without more extensive testing.
Could you advise me on where I went wrong?

I1202 14:25:55.659013 140285781463488 logging.py:60]
allow_tf32: true
logdir: logs
mixed_precision: fp16
num_checkpoint_limit: 10
num_epochs: 5
pretrained:
  model: /workspace/anything-v4.5
  revision: main
prompt_fn: anything_prompt
prompt_fn_kwargs: {}
resume_from: ''
reward_fn: jpeg_compressibility
run_name: anythingdpo_2023.12.02_14.25.53
sample:
  batch_size: 2
  eta: 1.0
  guidance_scale: 5.0
  num_batches_per_epoch: 100
  num_steps: 20
  save_interval: 100
save_freq: 1
seed: 42
train:
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_epsilon: 1.0e-08
  adam_weight_decay: 0.0001
  adv_clip_max: 5
  batch_size: 1
  beta: 0.1
  cfg: true
  gradient_accumulation_steps: 1
  json_path: /workspace/data/epoch1/json
  learning_rate: 1.0e-05
  max_grad_norm: 1.0
  num_inner_epochs: 1
  sample_path: /workspace/data/epoch1
  save_interval: 50
  timestep_fraction: 1.0
  use_8bit_adam: false
  use_lora: true

text_config_dict is provided which will be used to initialize CLIPTextConfig. The value text_config["id2label"] will be overriden.
I1202 14:26:06.847579 140285781463488 logging.py:60] ***** Running training *****
I1202 14:26:06.847816 140285781463488 logging.py:60] Num Epochs = 5
I1202 14:26:06.847866 140285781463488 logging.py:60] Sample batch size per device = 2
I1202 14:26:06.847907 140285781463488 logging.py:60] Train batch size per device = 1
I1202 14:26:06.847945 140285781463488 logging.py:60] Gradient Accumulation steps = 1
I1202 14:26:06.847987 140285781463488 logging.py:60]
I1202 14:26:06.848022 140285781463488 logging.py:60] Total number of samples per epoch = 200
I1202 14:26:06.848079 140285781463488 logging.py:60] Total train batch size (w. parallel, distributed & accumulation) = 1
I1202 14:26:06.848118 140285781463488 logging.py:60] Number of gradient updates per inner epoch = 200
I1202 14:26:06.848156 140285781463488 logging.py:60] Number of inner epochs = 1
I1202 15:19:19.564940 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0
I1202 15:19:19.665072 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0/optimizer.bin
48/200 [53:07<2:30:54, 59.57s/it]
I1202 15:19:19.670379 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0/scaler.pt
I1202 15:19:19.682694 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_0/random_states_0.pkl
I1202 16:18:57.735666 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1
I1202 16:18:57.821295 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1/optimizer.bin
99/200 [1:52:45<2:39:23, 94.68s/it]
I1202 16:18:57.827801 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1/scaler.pt
I1202 16:18:57.833970 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_1/random_states_0.pkl
I1202 17:20:06.462761 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2
I1202 17:20:06.565670 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2/optimizer.bin
148/200 [2:53:54<1:34:02, 108.51s/it]
I1202 17:20:06.572802 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2/scaler.pt
I1202 17:20:06.578633 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_2/random_states_0.pkl
I1202 18:30:18.542086 140285781463488 logging.py:60] Saving current state to logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3
I1202 18:30:18.639742 140285781463488 logging.py:60] Optimizer state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3/optimizer.bin
198/200 [4:04:06<01:59, 59.57s/it]
I1202 18:30:18.645961 140285781463488 logging.py:60] Gradient scaler state saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3/scaler.pt
I1202 18:30:18.650440 140285781463488 logging.py:60] Random states saved in logs/anythingdpo_2023.12.02_14.25.53/checkpoints/checkpoint_3/random_states_0.pkl
Traceback (most recent call last):
  File "/workspace/d3po/scripts/train.py", line 424, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/workspace/d3po/scripts/train.py", line 266, in main
    assert num_timesteps == config.sample.num_steps
AssertionError
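The run fails on `assert num_timesteps == config.sample.num_steps` (train.py, line 266) after the fourth checkpoint. Assuming d3po follows the DDPO-style convention where `num_timesteps` is the length of the time axis of the stored sampling trajectories, the assertion means the trajectories being loaded do not have `config.sample.num_steps` denoising steps (20 in this config). A hypothetical pre-flight check like the one below, run on the saved samples before training, replaces the bare AssertionError with a message that reports both numbers. The `(num_samples, num_timesteps)` shape is an assumption, and this is a sketch, not d3po's code.

```python
from typing import Sequence


def check_trajectory_steps(timesteps_shape: Sequence[int], num_steps: int) -> int:
    """Check that stored trajectories match config.sample.num_steps.

    timesteps_shape is assumed (hypothetically) to be
    (num_samples, num_timesteps): one row of timesteps per saved sample.
    Returns num_timesteps if it matches, otherwise raises ValueError.
    """
    if len(timesteps_shape) != 2:
        raise ValueError(
            f"expected a 2-D (samples, timesteps) shape, got {tuple(timesteps_shape)}"
        )
    num_timesteps = timesteps_shape[1]
    if num_timesteps != num_steps:
        raise ValueError(
            f"stored trajectories have {num_timesteps} denoising steps, "
            f"but config.sample.num_steps is {num_steps}; regenerate the "
            f"samples with the same num_steps or change the config to match"
        )
    return num_timesteps
```

For example, `check_trajectory_steps((200, 20), 20)` passes, while samples saved with a different step count raise an error naming both values.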

LICENSE

Thanks a lot for sharing your code! Could you also add a LICENSE file to your repo so that the usage conditions and restrictions are clear?
