-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add gits scheduler #3769
feat: add gits scheduler #3769
Conversation
This seems interesting. |
[14.61464119, 2.45070267, 1.32549286, 0.86115354, 0.64427125, 0.50118381, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], | ||
[14.61464119, 2.45070267, 1.36964464, 0.92192322, 0.69515091, 0.54755926, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], | ||
[14.61464119, 2.45070267, 1.41535246, 0.95350921, 0.72133851, 0.57119018, 0.4783645, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], | ||
], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you tell me where these values are from in the reference code or how you calculated them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you tell me where these values are from in the reference code or how you calculated them?
TL;DR
To avoid extra computational overhead for users, in the submitted script, we provide the sigmas pre-calculated on our device (requiring only several minutes using 4 A100 GPUs). Our proposed GITS utilizes the regularity of the sampling trajectories of diffusion models (they all look like a “boomerang” curve) and finds optimal time schedules using dynamic programming with a hyperparameter 'coeff'. Dynamic programming is performed on a cost matrix, where each cost evalutes the difference between an Euler step and the ground-truth prediction from the current position to the next position.
Derivation of GITS
In our ICML'24 paper On the Trajectory Regularity of ODE-based Diffusion Sampling, we illustrate that different sampling trajectories of diffusion models share almost the same shape structure (a “boomerang” curve), which inspires us to search a better time schedule for ODE-based diffusion sampling. Our approach is named as geometry-inspired time scheduling (GITS).
We first define a searching space denoted as
How to obtain these schedules
We run the following command in our repo which takes several minutes using 4 A100 GPUs:
SOLVER_FLAGS="--solver=dpmpp --num_steps=21 --afs=False"
SCHEDULE_FLAGS="--schedule_type=discrete --schedule_rho=1"
ADDITIONAL_FLAGS="--max_order=2 --predict_x0=False --lower_order_final=True"
GUIDANCE_FLAGS="--guidance_type=cfg --guidance_rate=7.5"
GITS_FLAGS="--dp=True --metric=dev --coeff=1.2 --num_steps_tea=101 --solver_tea=dpmpp"
torchrun --standalone --nproc_per_node=4 --master_port=11111 \
sample.py --dataset_name="ms_coco" --batch=16 --seeds="0-15" $SOLVER_FLAGS $SCHEDULE_FLAGS $ADDITIONAL_FLAGS $GUIDANCE_FLAGS $GITS_FLAGS
This command is used to generate GITS schedules for Stable Diffusion v1-5 using DPM-Solver++(2M) for sampling. The GITS_FLAGS
includes the settings for GITS where we set 100-step trajectory as the ground-truth. This command for Stable Diffusion will by default generate all time schedules at one time (including indices and sigmas), for each coefficient ./dp_record.txt
.
To avoid extra computational overhead for users, in the submitted script, we provide the sigmas generated for
I see that the examples all use SD1.5; is SDXL supported? And if so, do you still recommend the default "coeff" value of 1.20? Thanks! <3 |
Thanks for your attention! Due to limited resources, we have not yet conducted experiments on SDXL. We will test our method on SDXL in the next few days. We speculate that the schedule will not differ significantly, and the current schedule for SD1.5 should be applicable to SDXL. The larger the "coeff", the more biased the schedule is towards 0. So according to the schedules provided in nodes_align_your_steps.py, a default value of 1.20 or 1.25 should be fine. |
The geometry-inspired time scheduling (GITS, ICML'24) provides a series of time schedules (with a hyperparameter 'coeff') for fast sampling of diffusion models. The implementation follows that of nodes_align_your_steps.py. For steps <= 20 we provide GITS schedules and for steps > 20 we follow nodes_align_your_steps.py to perform a log-linear interpolation based on the 20-step GITS schedule.
Below is our workflow for comparison of these three schedules: