
[core][scheduler] simplify and improve scheduler #6867

Merged: 9 commits into vllm-project:main from inplace_update_queue on Aug 1, 2024

Conversation

youkaichao (Member)

fix the problem found in #6865 (comment).

avoid copying the whole queue in every step; update the queue in place instead.

remove the sorting policy, since the queues are always kept in first-come-first-served order.
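
A minimal sketch of the idea (illustrative only; sg and is_finished() are hypothetical stand-ins, not the actual vLLM scheduler code):

from collections import deque

def step_with_copy(running: deque) -> deque:
    # Old pattern: rebuild (and, with a sorting policy, re-sort) the whole
    # queue every scheduling step, allocating a fresh copy each time.
    return deque(sg for sg in running if not sg.is_finished())

def step_in_place(running: deque) -> None:
    # New pattern: cycle through the deque exactly once, dropping finished
    # entries and re-appending survivors. The same deque object is reused
    # across steps, and the relative (FCFS) order is preserved.
    for _ in range(len(running)):
        sg = running.popleft()
        if not sg.is_finished():
            running.append(sg)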

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI, as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao (Member Author)

Combining #6865 and this PR:

$ python benchmarks/benchmark_throughput.py --output-len 16 --input 16 --model facebook/opt-125m --num-prompts 1000

# 1k requests
Throughput: 694.62 requests/s, 22227.94 tokens/s

$ python benchmarks/benchmark_throughput.py --output-len 16 --input 16 --model facebook/opt-125m --num-prompts 10000

# 10k requests
Throughput: 735.86 requests/s, 23547.52 tokens/s

$ python benchmarks/benchmark_throughput.py --output-len 16 --input 16 --model facebook/opt-125m --num-prompts 100000

# 100k requests
Throughput: 656.60 requests/s, 21011.07 tokens/s

$ python benchmarks/benchmark_throughput.py --output-len 16 --input 16 --model facebook/opt-125m --num-prompts 500000

# 500k requests
Throughput: 609.16 requests/s, 19493.24 tokens/s

The throughput is now almost constant with respect to the number of requests.

@njhill (Member) commented Jul 29, 2024

@youkaichao could we keep the policy abstraction and just not use it in the default FCFS case? I know of research teams that are experimenting with different policies using this. Of course, we can document that there may be a nontrivial cost to using a different policy.
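
One possible shape of that suggestion, as a sketch only (the Policy / FCFS classes below mimic a sort_by_priority(now, seq_groups) hook roughly like the one this PR removes, and seq_group.arrival_time is an assumed attribute, not the actual vLLM code):

import time
from collections import deque

class Policy:
    # Stand-in for the pre-PR interface: sort a copy of the queue by a
    # per-request priority, highest priority first.
    def get_priority(self, now: float, seq_group) -> float:
        raise NotImplementedError

    def sort_by_priority(self, now: float, seq_groups: deque) -> deque:
        return deque(sorted(seq_groups,
                            key=lambda sg: self.get_priority(now, sg),
                            reverse=True))

class FCFS(Policy):
    def get_priority(self, now: float, seq_group) -> float:
        # Earlier arrival => larger value => scheduled first.
        return now - seq_group.arrival_time

def maybe_sort(policy: Policy, queue: deque) -> deque:
    # The suggestion: keep the Policy hook for research use, but skip the
    # O(n log n) copy-and-sort on the default FCFS path, since the queues
    # are already kept in arrival order.
    if type(policy) is FCFS:
        return queue
    return policy.sort_by_priority(time.time(), queue)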

WoosukKwon self-assigned this on Jul 30, 2024
@youkaichao (Member Author)

> @youkaichao could we keep the policy abstraction and just not use it in the default FCFS case? I know of research teams that are experimenting with different policies using this. Of course, we can document that there may be a nontrivial cost to using a different policy.

I don't want to support this. In my opinion, request prioritization should happen one layer above vLLM, e.g. in a load-balancing layer or a cross-instance scheduling layer.

vLLM is an inference engine; its goal is to return results as quickly as possible. It does not make sense to consider request priority here.

@njhill (Member) commented Jul 30, 2024

@youkaichao that's a valid opinion, but I think we should collect other opinions before making a decision on this.

I think this is an important PR/optimization, but removal of the Policy abstraction is a separate question that should be addressed in a separate issue/PR, IMO.

@youkaichao (Member Author)

Actually, the policy was never feature-complete. It only applies to the running queue and the swapped queue; no policy is applied to the waiting queue.

If you want a full-fledged research framework for request prioritization, I don't think vLLM is a good place for it.

For example, if a request arrives with the highest priority, do you stop or preempt the running requests? When you swap a request out and then back in, does its priority change? Is request priority a constant (determined when the request is received), or can it be dynamic, changing every iteration?

With all these open questions, and the current naive implementation of policy in vLLM, I don't think it is worth maintaining, especially given that we need to aim for performance now.

@njhill (Member) commented Jul 30, 2024

@youkaichao vLLM does already serve as a foundation for many different research activities.

Re prioritization features, see #6077 and #5958. @apatke and @saurabhjha1 may want to comment further!

@youkaichao (Member Author)

As discussed in #6077, per-iteration sorting can be removed, and we will introduce a priority interface later. So this PR is good to go now.

WoosukKwon self-requested a review on August 1, 2024, 02:08
@WoosukKwon (Collaborator) left a comment

@youkaichao Thanks for this simplification!

@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
         # Allow only 2 sequences of ~128 tokens in worst case.
         # Note 16 = 128/block_size
-        "num_gpu_blocks_override": 2 * (16 + 1),
+        "num_gpu_blocks_override": 2 * (16 + 2),
Collaborator

Just curious: What is this change for?

Member Author

This is a corner case: the change in swapping order from this PR means the previous number of blocks is no longer enough.
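
(For reference, reading the numbers off the test's own comments: each ~128-token sequence needs 128 / block_size = 16 blocks, so 2 * (16 + 1) = 34 blocks gave each of the two sequences one spare block, while 2 * (16 + 2) = 36 gives two, which provides enough headroom under the new swapping order.)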

Collaborator

@youkaichao Oh, why does the swapping order change? I thought the change in this PR doesn't actually change the FCFS logic at all.

Member Author

The running queue is FCFS. The swapped queue is more complicated when requests are swapped out and in again and again (which happens in this test case).

That being said, I don't think we had strict FCFS even before this PR once repeated swap out / swap in is taken into account. I think we can only guarantee strict FCFS once we use a priority queue in the future.
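
For illustration, one way a future priority queue could give strict FCFS across repeated swap out / swap in is to key every queue on the request's original arrival time (a hypothetical sketch, not vLLM code):

import heapq
import itertools

_tie_breaker = itertools.count()  # keeps ordering stable for equal timestamps

def push(queue: list, seq_group, arrival_time: float) -> None:
    # A request that is swapped out and later swapped back in is re-pushed
    # with its original arrival_time, so it keeps its FCFS position.
    heapq.heappush(queue, (arrival_time, next(_tie_breaker), seq_group))

def pop_next(queue: list):
    _, _, seq_group = heapq.heappop(queue)
    return seq_group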

youkaichao merged commit c8a7e93 into vllm-project:main on Aug 1, 2024
28 checks passed
youkaichao deleted the inplace_update_queue branch on August 1, 2024, 06:53
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024