Inference backend support system prompt #3313

Merged: 5 commits, Jun 7, 2023
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -140,6 +140,7 @@ async def create_assistant_message(
work_parameters = inference.WorkParameters(
model_config=model_config,
sampling_parameters=request.sampling_parameters,
system_prompt=request.system_prompt,
plugins=request.plugins,
plugin_max_depth=settings.plugin_max_depth,
)
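For orientation, a stripped-down sketch of what the route above now does with the new field (the helper name and argument list are illustrative, not part of the PR): the client-supplied system_prompt is forwarded unchanged into the WorkParameters that get queued for a worker.

from oasst_shared.schemas import inference

# Illustrative helper only; in the actual route this happens inline in
# create_assistant_message. `request` is the parsed CreateAssistantMessageRequest.
def build_work_parameters(request, model_config, settings) -> inference.WorkParameters:
    return inference.WorkParameters(
        model_config=model_config,
        sampling_parameters=request.sampling_parameters,
        system_prompt=request.system_prompt,  # may be None; forwarded as-is
        plugins=request.plugins,
        plugin_max_depth=settings.plugin_max_depth,
    )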
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -14,6 +14,7 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
parent_id: str
model_config_name: str
sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
system_prompt: str | None = None
plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
used_plugin: inference.PluginUsed | None = None

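Because the field defaults to None, existing clients need no changes. A small usage sketch (made-up values; it assumes the inference server package is importable):

from oasst_inference_server.schemas.chat import CreateAssistantMessageRequest

# Without a system prompt the request validates exactly as before.
req = CreateAssistantMessageRequest(parent_id="parent-message-id", model_config_name="some-model-config")
assert req.system_prompt is None

# With a system prompt, the string is carried through to the worker untouched.
req = CreateAssistantMessageRequest(
    parent_id="parent-message-id",
    model_config_name="some-model-config",
    system_prompt="Answer as briefly as possible.",
)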
1 change: 1 addition & 0 deletions inference/worker/chat_chain_prompts.py
@@ -1,5 +1,6 @@
V2_ASST_PREFIX = "<|assistant|>"
V2_PROMPTER_PREFIX = "<|prompter|>"
V2_SYSTEM_PREFIX = "<|system|>"

ASSISTANT_PREFIX = "Open Assistant"
HUMAN_PREFIX = "Human"
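For reference, the prompt layout these prefixes produce once the worker (work.py below) prepends a system section; the eos token value here is an assumption, since it is model-specific in practice:

from chat_chain_prompts import V2_ASST_PREFIX, V2_PROMPTER_PREFIX, V2_SYSTEM_PREFIX

eos_token = "</s>"  # assumption: taken from the model's tokenizer in practice

prompt = (
    V2_SYSTEM_PREFIX + "You are a concise, polite assistant." + eos_token
    + V2_PROMPTER_PREFIX + "What is the capital of France?" + eos_token
    + V2_ASST_PREFIX
)
# -> <|system|>You are a concise, polite assistant.</s><|prompter|>What is the capital of France?</s><|assistant|>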
19 changes: 17 additions & 2 deletions inference/worker/utils.py
@@ -11,7 +11,7 @@
import sseclient
import transformers
import websocket
from chat_chain_prompts import V2_PROMPTER_PREFIX
from chat_chain_prompts import V2_PROMPTER_PREFIX, V2_SYSTEM_PREFIX
from loguru import logger
from oasst_shared.schemas import inference
from settings import settings
@@ -80,12 +80,23 @@ def truncate_prompt(
):
with shared_tokenizer_lock:
ids = tokenizer.encode(prompt)
prompter_prefix_id = tokenizer.convert_tokens_to_ids(V2_PROMPTER_PREFIX)

system_prompt: str | None = None
system_tokens: list[int] | None = None
if prompt.startswith(V2_SYSTEM_PREFIX):
system_prompt = prompt[: prompt.index(V2_PROMPTER_PREFIX)]
system_tokens = ids[: ids.index(prompter_prefix_id)]

max_input_length = get_max_input_length(worker_config, plugin_used)

if len(ids) > max_input_length:
logger.debug(f"Prompt too long, left-truncating to {max_input_length} tokens")
ids = ids[-(max_input_length - 1) :]

num_system_tokens = len(system_tokens) if system_tokens else 0
# Maximum token allowed for the conversation, ex system prompt
max_conversation_length = max_input_length - num_system_tokens
ids = ids[-(max_conversation_length - 1) :]

with shared_tokenizer_lock:
prompt = tokenizer.decode(ids)
@@ -94,6 +105,10 @@
prompt = V2_PROMPTER_PREFIX + prompt
ids = tokenizer.encode(V2_PROMPTER_PREFIX) + ids

if system_tokens:
prompt = system_prompt + prompt
ids = system_tokens + ids

max_total_tokens = worker_config.model_config.max_total_length
input_length = len(ids)
spare = max_total_tokens - input_length - 1
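The core idea of the change above, as a self-contained sketch that uses plain lists of token ids instead of a real tokenizer (function name and example values are made up): when the prompt is too long, only the conversation part is left-truncated and the system tokens are re-attached at the front, so the system prompt always survives.

def truncate_keep_system(ids: list[int], system_tokens: list[int] | None, max_input_length: int) -> list[int]:
    # `ids` is the full encoded prompt, including the system section if present.
    if len(ids) <= max_input_length:
        return ids
    num_system_tokens = len(system_tokens) if system_tokens else 0
    # Token budget for the conversation alone, excluding the system prompt.
    max_conversation_length = max_input_length - num_system_tokens
    ids = ids[-(max_conversation_length - 1):]  # left-truncate: keep the most recent turns
    if system_tokens:
        ids = system_tokens + ids  # re-attach the system section at the front
    return ids

# Made-up ids: system section [1, 2], ten conversation tokens, budget of 8.
assert truncate_keep_system([1, 2] + list(range(10, 20)), [1, 2], 8) == [1, 2, 15, 16, 17, 18, 19]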
6 changes: 6 additions & 0 deletions inference/worker/work.py
@@ -15,6 +15,7 @@
THOUGHT_SEQ,
V2_ASST_PREFIX,
V2_PROMPTER_PREFIX,
V2_SYSTEM_PREFIX,
)
from loguru import logger
from oasst_shared.schemas import inference
@@ -38,13 +39,18 @@ def _prepare_message(message: inference.MessageRead) -> str:
# construct prompt
messages = [_prepare_message(message) for message in work_request.thread.messages]

if work_request.parameters.system_prompt:
pre_prompt = V2_SYSTEM_PREFIX + work_request.parameters.system_prompt + eos_token
messages = [pre_prompt] + messages

prompt = "".join(messages) + V2_ASST_PREFIX

parameters = interface.GenerateStreamParameters.from_work_parameters(work_request.parameters)
if settings.use_stop_sequences:
parameters.stop = [
V2_PROMPTER_PREFIX,
V2_ASST_PREFIX,
V2_SYSTEM_PREFIX,
]
if eos_token:
parameters.stop.append(eos_token)
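A compressed sketch of the assembly step above (helper name and signature are illustrative): the system section is prepended only when a system prompt was supplied, and V2_SYSTEM_PREFIX joins the stop sequences so the model cannot open a new system section of its own.

from chat_chain_prompts import V2_ASST_PREFIX, V2_PROMPTER_PREFIX, V2_SYSTEM_PREFIX

def build_prompt(messages: list[str], system_prompt: str | None, eos_token: str) -> tuple[str, list[str]]:
    # `messages` are already formatted turns (prefix + text + eos), as in _prepare_message.
    if system_prompt:
        messages = [V2_SYSTEM_PREFIX + system_prompt + eos_token] + messages
    prompt = "".join(messages) + V2_ASST_PREFIX
    stop = [V2_PROMPTER_PREFIX, V2_ASST_PREFIX, V2_SYSTEM_PREFIX]
    if eos_token:
        stop.append(eos_token)
    return prompt, stop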
1 change: 1 addition & 0 deletions oasst-shared/oasst_shared/schemas/inference.py
@@ -200,6 +200,7 @@ class WorkParameters(pydantic.BaseModel):
seed: int = pydantic.Field(
default_factory=make_seed,
)
system_prompt: str | None = None
plugins: list[PluginEntry] = pydantic.Field(default_factory=list[PluginEntry])
plugin_max_depth: int = 4

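A toy check of the default behaviour (a stand-in model, not the real WorkParameters): omitting the field leaves it as None, so existing work requests are unaffected.

import pydantic

class WorkParametersSketch(pydantic.BaseModel):
    # Stand-in for oasst_shared.schemas.inference.WorkParameters, reduced to the new field.
    system_prompt: str | None = None

assert WorkParametersSketch().system_prompt is None
assert WorkParametersSketch(system_prompt="Be terse.").system_prompt == "Be terse."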