Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Streamed tool calls now more strictly follow OpenAI's format; ensures Vercel AI SDK compatibility #8272

Merged
merged 8 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix: streaming chunks
  • Loading branch information
Kyle Mistele committed Sep 7, 2024
commit 3bb57ccb9766b48a88fc5e295d73364f12986bb2
10 changes: 8 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,17 @@ async def chat_completion_stream_generator(
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):

# TODO - this breaks; it needs to include ALL fields
K-Mistele marked this conversation as resolved.
Show resolved Hide resolved
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
delta=DeltaMessage(
role=role,
content="",
),
logprobs=None,
finish_reason=None)
finish_reason=None
)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(self, tokenizer: AnyTokenizer):
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = []

self.model_tokenizer = tokenizer
Expand Down
48 changes: 19 additions & 29 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
InitialDeltaToolCall, ToolCall)
ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
Expand All @@ -34,7 +34,6 @@ def __init__(self, tokenizer: AnyTokenizer):
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list

Expand Down Expand Up @@ -95,7 +94,7 @@ def extract_tool_calls(self,
]

content = model_output[:model_output.
find(self.tool_call_start_token)]
find(self.tool_call_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
Expand All @@ -109,13 +108,13 @@ def extract_tool_calls(self,
content=model_output)

def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:

logger.debug("delta_text: %s", delta_text)
Expand Down Expand Up @@ -168,7 +167,6 @@ def extract_tool_calls_streaming(
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.current_tool_initial_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)

Expand Down Expand Up @@ -198,7 +196,7 @@ def extract_tool_calls_streaming(
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
exclude_none=True))
])

# case -- otherwise we're just generating text
Expand All @@ -218,27 +216,19 @@ def extract_tool_calls_streaming(
logger.debug('not enough tokens to parse into JSON yet')
return None

# case - we haven't sent the initial delta with the tool call ID yet,
# so send it now
if not self.current_tool_initial_sent:
self.current_tool_initial_sent = True
return DeltaMessage(tool_calls=[
InitialDeltaToolCall(
index=self.current_tool_id).model_dump(
exclude_none=True)
])

# case - we haven't sent the tool name yet. If it's available, send
# it. Otherwise, wait until it's available.
elif not self.current_tool_name_sent:
if not self.current_tool_name_sent:
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True)
)
])
else:
return None
Expand Down Expand Up @@ -305,7 +295,7 @@ def extract_tool_calls_streaming(
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).model_dump(
exclude_none=True))
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= arguments_delta
Expand All @@ -324,7 +314,7 @@ def extract_tool_calls_streaming(
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= argument_diff
Expand Down
12 changes: 1 addition & 11 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,9 @@ def extract_tool_calls_streaming(

# case: update an existing tool - this is handled below

# if the current tool's initial data (incl. the id, type=function
# and idx) has not been sent yet, send it
if not self.current_tool_initial_sent:
self.current_tool_initial_sent = True
delta = DeltaMessage(tool_calls=[
InitialDeltaToolCall(
index=self.current_tool_id).model_dump(
exclude_none=True)
])

# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:

Expand Down