
[Mistral] Mistral-7B-v0.1 support #1196

Merged on Sep 28, 2023 (7 commits)

Changes from 3 commits
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,7 +7,7 @@ sentencepiece # Required for LLaMA tokenizer.
numpy
torch >= 2.0.0
transformers >= 4.33.1 # Required for Code Llama.
-xformers >= 0.0.21
+xformers >= 0.0.22
fastapi
uvicorn
pydantic < 2 # Required for OpenAI server.
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -187,10 +187,12 @@ def __init__(
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
+        sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
+        self.sliding_window = sliding_window
self._verify_args()

# Will be set after profiling.
41 changes: 30 additions & 11 deletions vllm/core/block_manager.py
@@ -63,10 +63,18 @@ def __init__(
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
+        sliding_window: Optional[int] = None,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks

+        self.block_sliding_window = None
+        if sliding_window is not None:
+            assert sliding_window % block_size == 0, (sliding_window,
+                                                      block_size)
+            self.block_sliding_window = sliding_window // block_size

self.watermark = watermark
assert watermark >= 0.0

@@ -83,6 +91,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> bool:
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs()[0]
num_required_blocks = len(seq.logical_token_blocks)
+        if self.block_sliding_window is not None:
+            num_required_blocks = min(num_required_blocks,
+                                      self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
return (num_free_gpu_blocks - num_required_blocks >=
@@ -95,8 +106,12 @@ def allocate(self, seq_group: SequenceGroup) -> None:

# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
-        for _ in range(len(seq.logical_token_blocks)):
-            block = self.gpu_allocator.allocate()
+        for logical_idx in range(len(seq.logical_token_blocks)):
+            if (self.block_sliding_window is not None
+                    and logical_idx >= self.block_sliding_window):
+                block = block_table[logical_idx % self.block_sliding_window]
+            else:
+                block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
block_table.append(block)
@@ -118,11 +133,17 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
block_table = self.block_tables[seq.seq_id]

if len(block_table) < len(logical_blocks):
-            # The sequence has a new logical block.
-            # Allocate a new physical block.
-            block = self.gpu_allocator.allocate()
-            block_table.append(block)
-            return None
+            if (self.block_sliding_window
+                    and len(block_table) >= self.block_sliding_window):
+                # re-use a block
+                block_table.append(block_table[len(block_table) %
+                                               self.block_sliding_window])
+            else:
+                # The sequence has a new logical block.
+                # Allocate a new physical block.
+                block = self.gpu_allocator.allocate()
+                block_table.append(block)
+            return None

# We want to append the token to the last physical block.
last_block = block_table[-1]
@@ -154,9 +175,7 @@ def _get_physical_blocks(
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
-            block_table = self.block_tables[seq.seq_id]
-            for block in block_table:
-                blocks.add(block)
+            blocks.update(self.block_tables[seq.seq_id])
return list(blocks)

def can_swap_in(self, seq_group: SequenceGroup) -> bool:
@@ -224,7 +243,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
return block_number_mapping

def _free_block_table(self, block_table: BlockTable) -> None:
-        for block in block_table:
+        for block in set(block_table):
if block.device == Device.GPU:
self.gpu_allocator.free(block)
else:
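
Note on the block-manager change: with a sliding window configured, at most block_sliding_window physical blocks are allocated per sequence, and later logical blocks wrap back onto them modulo that count. A minimal standalone sketch of the mapping, using toy values and a hypothetical physical_slot helper rather than the vLLM classes:

# Sketch of the logical-to-physical block mapping under a sliding window.
# Assumed toy values; the real logic lives in BlockSpaceManager above.
block_size = 16
sliding_window = 64                                   # multiple of block_size
block_sliding_window = sliding_window // block_size   # 4 physical blocks

def physical_slot(logical_idx: int) -> int:
    # Logical blocks past the window wrap onto already-allocated slots.
    if logical_idx >= block_sliding_window:
        return logical_idx % block_sliding_window
    return logical_idx

print([physical_slot(i) for i in range(10)])
# -> [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]

The wrap-around also means one physical block can appear several times in a block table, which is why _free_block_table now iterates over set(block_table).
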
2 changes: 1 addition & 1 deletion vllm/core/scheduler.py
@@ -73,7 +73,7 @@ def __init__(
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
-        )
+            sliding_window=self.cache_config.sliding_window)

# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state.
6 changes: 3 additions & 3 deletions vllm/engine/arg_utils.py
@@ -177,9 +177,9 @@ def create_engine_configs(
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.max_model_len, self.quantization)
-        cache_config = CacheConfig(self.block_size,
-                                   self.gpu_memory_utilization,
-                                   self.swap_space)
+        cache_config = CacheConfig(
+            self.block_size, self.gpu_memory_utilization, self.swap_space,
+            getattr(model_config.hf_config, 'sliding_window', None))
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray)
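
Note: the window size is read off the Hugging Face model config when present and left as None otherwise. A rough sketch of that lookup (assumes transformers is installed and the configs can be fetched; the printed values are what the published configs are expected to contain, not hard guarantees):

# Sketch of the getattr-based lookup used above.
from transformers import AutoConfig

mistral_cfg = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
print(getattr(mistral_cfg, "sliding_window", None))   # expected: 4096

opt_cfg = AutoConfig.from_pretrained("facebook/opt-125m")
print(getattr(opt_cfg, "sliding_window", None))       # None: no sliding window
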
4 changes: 3 additions & 1 deletion vllm/engine/llm_engine.py
@@ -86,6 +86,8 @@ def __init__(

self.model_config = model_config
self.cache_config = cache_config
+        assert self.cache_config.sliding_window == getattr(
+            self.model_config.hf_config, "sliding_window", None)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
@@ -660,7 +662,7 @@ def _check_stop(self, seq: Sequence,
return

# Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
+        if seq.get_output_len() >= sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

21 changes: 20 additions & 1 deletion vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

import torch
from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ def __init__(
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
+        sliding_window: Optional[int] = None,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
@@ -38,6 +39,24 @@ def __init__(
self.max_context_len = max_context_len
self.block_tables = block_tables

+        self.to_cache = None
+        if sliding_window is not None:
+            # We need to keep the positions of sliding windows within
+            # the key / value tables, this is helpful to know which
+            # elements we need to cache and where
+            to_cache, start_idx = [], 0
+            for prompt_len in self.prompt_lens:
+                to_cache.extend(
+                    range(
+                        start_idx + max(0, prompt_len - sliding_window),
+                        start_idx + prompt_len,
+                    ))
+                start_idx += prompt_len
+            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
+            self.to_cache = torch.tensor(to_cache,
+                                         dtype=torch.int32,
+                                         device=self.slot_mapping.device)

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
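
Note on to_cache: for each prompt only the final sliding_window positions are written to the KV cache, while all generation-token slots are kept. A self-contained sketch of the index computation with made-up sizes (the real code uses self.prompt_lens and slot_mapping.shape[0]):

# Sketch of the to_cache index computation with toy values.
sliding_window = 4
prompt_lens = [3, 6]           # two prompts in one batch
num_slots = sum(prompt_lens)   # pretend there are no generation tokens here

to_cache, start_idx = [], 0
for prompt_len in prompt_lens:
    to_cache.extend(range(start_idx + max(0, prompt_len - sliding_window),
                          start_idx + prompt_len))
    start_idx += prompt_len
to_cache.extend(range(start_idx, num_slots))

print(to_cache)
# prompt 0 (len 3 <= window): all of positions 0, 1, 2 are cached
# prompt 1 (len 6 > window):  only its last 4 positions, 5..8, are cached
# -> [0, 1, 2, 5, 6, 7, 8]
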
31 changes: 24 additions & 7 deletions vllm/model_executor/layers/attention.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+from xformers.ops.fmha.attn_bias import (BlockDiagonalMask,
LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
@@ -58,12 +58,14 @@ def __init__(self,
num_heads: int,
head_size: int,
scale: float,
-                 num_kv_heads: Optional[int] = None) -> None:
+                 num_kv_heads: Optional[int] = None,
+                 sliding_window: Optional[int] = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -85,7 +87,9 @@ def set_attn_bias(
# Already set by a previous layer.
return
prompt_lens = input_metadata.prompt_lens
-        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        attn_bias = BlockDiagonalMask.from_seqlens(prompt_lens)
+        if self.sliding_window is not None:
+            attn_bias = attn_bias.make_local_attention(self.sliding_window)
input_metadata.attn_bias.append(attn_bias)

def multi_query_kv_attention(
@@ -223,12 +227,20 @@ def forward(
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
+            key_to_cache = key[:num_valid_tokens]
+            value_to_cache = value[:num_valid_tokens]
+            slot_mapping = input_metadata.slot_mapping
+            if input_metadata.to_cache is not None:
+                key_to_cache = key_to_cache[input_metadata.to_cache]
+                value_to_cache = value_to_cache[input_metadata.to_cache]
+                slot_mapping = slot_mapping[input_metadata.to_cache]

cache_ops.reshape_and_cache(
-                key[:num_valid_tokens],
-                value[:num_valid_tokens],
+                key_to_cache,
+                value_to_cache,
key_cache,
value_cache,
-                input_metadata.slot_mapping,
+                slot_mapping,
)

if input_metadata.num_generation_tokens > 0:
@@ -262,8 +274,13 @@ def __init__(
num_kv_heads: Optional[int] = None,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
+        sliding_window: Optional[int] = None,
) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads)
+        super().__init__(num_heads,
+                         head_size,
+                         scale,
+                         num_kv_heads,
+                         sliding_window=sliding_window)
if rope_scaling is None:
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
max_position, base,
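
Note on the mask change: BlockDiagonalMask.from_seqlens(...).make_local_attention(window) keeps causality but restricts each query to a window of recent keys, which is what the xformers >= 0.0.22 bump is for. A plain-PyTorch illustration of the intended pattern for one short sequence; the exact boundary convention is an assumption here, not taken from the xformers kernels:

import torch

# Causal local-attention pattern for one sequence of length 6 with a window
# of 3: query i may attend to keys max(0, i - 2) .. i (assumed convention).
seq_len, window = 6, 3
i = torch.arange(seq_len).unsqueeze(1)   # query positions
j = torch.arange(seq_len).unsqueeze(0)   # key positions
allowed = (j <= i) & (j > i - window)
print(allowed.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
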
1 change: 1 addition & 0 deletions vllm/model_executor/model_loader.py
@@ -25,6 +25,7 @@
"InternLMForCausalLM": InternLMForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MistralForCausalLM": MistralForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -12,6 +12,7 @@
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
+from vllm.model_executor.models.mistral import MistralForCausalLM

__all__ = [
"AquilaForCausalLM",
@@ -28,4 +29,5 @@
"MPTForCausalLM",
"OPTForCausalLM",
"QWenLMHeadModel",
"MistralForCausalLM",
]
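
With MistralForCausalLM registered, the model should be loadable through the usual entry point. A quick usage sketch, assuming the Hugging Face weights can be downloaded and a GPU with enough memory is available:

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mistral-7B-v0.1")
params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
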