Automatically configure max_num_batched_tokens (vllm-project#1198)
WoosukKwon authored Sep 27, 2023
1 parent 28e616c commit a19bc5c
Showing 2 changed files with 35 additions and 11 deletions.
43 changes: 34 additions & 9 deletions vllm/config.py
@@ -266,11 +266,36 @@ class SchedulerConfig:
             and generated text).
     """
 
-    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_model_len: int) -> None:
-        self.max_num_batched_tokens = max_num_batched_tokens
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
+            self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.max_num_batched_tokens < self.max_model_len:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -350,14 +375,14 @@ def _get_and_verify_max_len(
         max_len_key = getattr(hf_config, key, None)
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if derived_max_model_len == float("inf"):
+        raise ValueError(
+            "The model's config.json must contain one of the following keys "
+            "to determine the original maximum length of the model: "
+            f"{possible_keys}")
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
-        if derived_max_model_len == float("inf"):
-            raise ValueError(
-                "When using rope_scaling, the model's config.json must "
-                "contain one of the following keys to determine the original "
-                f"maximum length of the model: {possible_keys}")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         derived_max_model_len *= scaling_factor
@@ -371,4 +396,4 @@ def _get_and_verify_max_len(
             " in model's config.json). This may lead to incorrect model "
             "outputs or CUDA errors. Make sure the value is correct and "
             "within the model context size.")
-    return max_model_len
+    return int(max_model_len)
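
For illustration only (this is not part of the commit), a minimal sketch of the new defaulting and validation behavior, assuming SchedulerConfig can be imported from vllm.config and constructed directly; the lengths used here are arbitrary:

from vllm.config import SchedulerConfig

# Leaving max_num_batched_tokens as None makes the scheduler derive it from
# max_model_len, with a floor of 2048 tokens.
cfg = SchedulerConfig(max_num_batched_tokens=None,
                      max_num_seqs=256,
                      max_model_len=4096)
assert cfg.max_num_batched_tokens == 4096  # max(4096, 2048)

cfg = SchedulerConfig(max_num_batched_tokens=None,
                      max_num_seqs=256,
                      max_model_len=512)
assert cfg.max_num_batched_tokens == 2048  # max(512, 2048)

# An explicit value smaller than max_model_len is now rejected up front.
try:
    SchedulerConfig(max_num_batched_tokens=1024,
                    max_num_seqs=256,
                    max_model_len=4096)
except ValueError as err:
    print(err)
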
3 changes: 1 addition & 2 deletions vllm/engine/arg_utils.py
@@ -25,7 +25,7 @@ class EngineArgs:
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: int = 2560
+    max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
@@ -34,7 +34,6 @@ def __post_init__(self):
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
-        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
 
     @staticmethod
     def add_cli_args(
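
For illustration only (this is not part of the commit), a short sketch of the engine-level effect, assuming EngineArgs can be constructed directly as a dataclass and using an arbitrary model name: max_num_batched_tokens now defaults to None and max_num_seqs is no longer clamped in __post_init__, so the scheduler chooses the token budget later.

from vllm.engine.arg_utils import EngineArgs

# With the new default, max_num_batched_tokens stays None here and is
# resolved to max(max_model_len, 2048) inside SchedulerConfig.
args = EngineArgs(model="facebook/opt-125m")  # model name is arbitrary
print(args.max_num_batched_tokens)  # None
print(args.max_num_seqs)            # 256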
