
Commit f568bae
Merge pull request ggerganov#351 from player1537-forks/th/add-logits-bias-parameter

Add support for `logit_bias` and `logit_bias_type` parameters
abetlen authored Jun 15, 2023
2 parents abf6d4a + eb7645b commit f568bae
Showing 2 changed files with 53 additions and 2 deletions.
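
For context, a minimal client-side sketch of the feature this commit adds. The /v1/completions endpoint and the `logit_bias` / `logit_bias_type` fields come from the diff below; the port, prompt, and token IDs are illustrative assumptions (real token IDs depend on the loaded model's vocabulary).

    import requests  # assumes a llama-cpp-python server is running locally

    response = requests.post(
        "http://localhost:8000/v1/completions",  # default port is an assumption
        json={
            "prompt": "The capital of France is",
            "max_tokens": 16,
            # "input_ids" mode (the default): keys are token IDs as strings,
            # values are offsets added to those tokens' logits before sampling.
            "logit_bias": {"1000": -10.0, "2000": 5.0},  # hypothetical token IDs
            "logit_bias_type": "input_ids",
        },
    )
    print(response.json()["choices"][0]["text"])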
2 changes: 2 additions & 0 deletions llama_cpp/llama.py
@@ -1378,6 +1378,7 @@ def create_chat_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -1419,6 +1420,7 @@ def create_chat_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
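
The llama.py change above simply threads an optional logits_processor argument through create_chat_completion into the underlying completion call. A minimal sketch of a custom processor passed the same way (LogitsProcessorList appears in the server code below; the model path is a hypothetical placeholder):

    import llama_cpp

    llama = llama_cpp.Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical path

    def dampen_logits(input_ids, scores):
        # Toy processor: shrink every logit before sampling.
        return [s * 0.5 for s in scores]

    out = llama.create_chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        logits_processor=llama_cpp.LogitsProcessorList([dampen_logits]),
    )
    print(out["choices"][0]["message"]["content"])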
53 changes: 51 additions & 2 deletions llama_cpp/server/app.py
@@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
     )
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters
@@ -280,6 +281,39 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)


+def make_logit_bias_processor(
+    llama: llama_cpp.Llama,
+    logit_bias: Dict[str, float],
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
+):
+    if logit_bias_type is None:
+        logit_bias_type = "input_ids"
+
+    to_bias: Dict[int, float] = {}
+    if logit_bias_type == "input_ids":
+        for input_id, score in logit_bias.items():
+            input_id = int(input_id)
+            to_bias[input_id] = score
+
+    elif logit_bias_type == "tokens":
+        for token, score in logit_bias.items():
+            token = token.encode('utf-8')
+            for input_id in llama.tokenize(token, add_bos=False):
+                to_bias[input_id] = score
+
+    def logit_bias_processor(
+        input_ids: List[int],
+        scores: List[float],
+    ) -> List[float]:
+        new_scores = [None] * len(scores)
+        for input_id, score in enumerate(scores):
+            new_scores[input_id] = score + to_bias.get(input_id, 0.0)
+
+        return new_scores
+
+    return logit_bias_processor
+
+
 @router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
@@ -297,9 +331,16 @@ async def create_completion(
         "n",
         "best_of",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)

@@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters
@@ -419,9 +461,16 @@ async def create_chat_completion(
     exclude = {
         "n",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)

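
In "tokens" mode, make_logit_bias_processor resolves each logit_bias key through the model's tokenizer (without a BOS token) rather than treating it as a raw token ID, so one key may bias several token IDs. A hedged sketch of an equivalent chat request (port, message, key, and bias value are illustrative):

    import requests  # assumes the same local llama-cpp-python server

    response = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "Name a European capital."}],
            # Each key is tokenized server-side; every resulting token ID
            # receives the same +10.0 logit offset.
            "logit_bias": {"Paris": 10.0},  # illustrative key and value
            "logit_bias_type": "tokens",
        },
    )
    print(response.json()["choices"][0]["message"]["content"])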
