add provider checking based on model name and provider (#571)
PCSwingle authored Apr 22, 2024
1 parent e51fdfb commit 03cc739
Showing 4 changed files with 148 additions and 94 deletions.
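In summary, the change prefers an explicitly configured provider and otherwise infers the provider from the model name, warning when neither is recognized. A rough, self-contained sketch of that resolution logic, using the spice helpers imported in the diffs below (the function and its error handling are illustrative, not the exact Mentat code):

```python
from spice.errors import InvalidProviderError
from spice.spice import get_model_from_name, get_provider_from_name


def resolve_provider(model_name: str, provider_name: str | None):
    """Prefer an explicitly configured provider; otherwise infer one from the model name."""
    if provider_name is not None:
        try:
            return get_provider_from_name(provider_name)
        except InvalidProviderError:
            # Unknown provider name: the real code warns and keeps the model-derived provider.
            pass
    # Assumption based on the diff: unknown models come back with provider set to None,
    # which the caller reports as "Unknown model".
    return get_model_from_name(model_name).provider
```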
17 changes: 10 additions & 7 deletions mentat/conversation.py
@@ -14,6 +14,7 @@
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from spice.errors import InvalidProviderError, UnknownModelError

from mentat.llm_api_handler import (
TOKEN_COUNT_WARNING,
@@ -91,9 +92,11 @@ async def count_tokens(
) -> int:
ctx = SESSION_CONTEXT.get()

_messages = await self.get_messages(system_prompt=system_prompt, include_code_message=include_code_message)
model = ctx.config.model
return ctx.llm_api_handler.spice.count_prompt_tokens(_messages, model)
try:
_messages = await self.get_messages(system_prompt=system_prompt, include_code_message=include_code_message)
return ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model, ctx.config.provider)
except (UnknownModelError, InvalidProviderError):
return 0

async def get_messages(
self,
@@ -126,7 +129,7 @@ async def get_messages(

if include_code_message:
code_message = await ctx.code_context.get_code_message(
ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model),
ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model, ctx.config.provider),
prompt=(
prompt # Prompt can be image as well as text
if isinstance(prompt, str)
@@ -186,7 +189,7 @@ async def _stream_model_response(
terminate=True,
)

num_prompt_tokens = llm_api_handler.spice.count_prompt_tokens(messages, config.model)
num_prompt_tokens = llm_api_handler.spice.count_prompt_tokens(messages, config.model, config.provider)
stream.send(f"Total token count: {num_prompt_tokens}", style="info")
if num_prompt_tokens > TOKEN_COUNT_WARNING:
stream.send(
@@ -220,7 +223,7 @@ async def get_model_response(self) -> ParsedLLMResponse:
llm_api_handler = session_context.llm_api_handler

messages_snapshot = await self.get_messages(include_code_message=True)
tokens_used = llm_api_handler.spice.count_prompt_tokens(messages_snapshot, config.model)
tokens_used = llm_api_handler.spice.count_prompt_tokens(messages_snapshot, config.model, config.provider)
raise_if_context_exceeds_max(tokens_used)

try:
@@ -238,7 +241,7 @@ async def get_model_response(self) -> ParsedLLMResponse:
async def remaining_context(self) -> int | None:
ctx = SESSION_CONTEXT.get()
return get_max_tokens() - ctx.llm_api_handler.spice.count_prompt_tokens(
await self.get_messages(), ctx.config.model
await self.get_messages(), ctx.config.model, ctx.config.provider
)

async def can_add_to_context(self, message: str) -> bool:
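As a side note on the count_tokens hunk above: an unrecognized model or provider now yields a count of 0 instead of raising. A minimal sketch of that fallback, isolated from the Conversation class (parameter types are assumptions based on how the call appears in the diff):

```python
from spice import Spice, SpiceMessage
from spice.errors import InvalidProviderError, UnknownModelError


def safe_token_count(spice: Spice, messages: list[SpiceMessage], model: str, provider: str | None) -> int:
    """Count prompt tokens, treating an unknown model/provider as 0 (mirrors the hunk above)."""
    try:
        return spice.count_prompt_tokens(messages, model, provider)
    except (UnknownModelError, InvalidProviderError):
        # An unrecognized model or provider no longer aborts the flow;
        # callers should read 0 as "count unavailable", not only "empty prompt".
        return 0
```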
88 changes: 73 additions & 15 deletions mentat/llm_api_handler.py
@@ -19,10 +19,10 @@
from dotenv import load_dotenv
from openai.types.chat.completion_create_params import ResponseFormat
from spice import EmbeddingResponse, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse, TranscriptionResponse
from spice.errors import APIConnectionError, NoAPIKeyError
from spice.errors import APIConnectionError, AuthenticationError, InvalidProviderError, NoAPIKeyError
from spice.models import WHISPER_1
from spice.providers import OPEN_AI
from spice.spice import UnknownModelError, get_model_from_name
from spice.spice import UnknownModelError, get_model_from_name, get_provider_from_name

from mentat.errors import MentatError, ReturnToUser
from mentat.session_context import SESSION_CONTEXT
@@ -58,13 +58,20 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return await func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return async_wrapper # pyright: ignore[reportReturnType]
else:
@@ -73,13 +80,20 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return sync_wrapper

@@ -142,19 +156,63 @@ async def initialize_client(self):
if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
load_dotenv()

try:
self.spice.load_provider(OPEN_AI)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

user_provider = get_model_from_name(ctx.config.model).provider
if ctx.config.provider is not None:
try:
user_provider = get_provider_from_name(ctx.config.provider)
except InvalidProviderError:
ctx.stream.send(
f"Unknown provider {ctx.config.provider}. Use /config provider <provider> to set your provider.",
style="warning",
)
elif user_provider is None:
ctx.stream.send(
"No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
f"Unknown model {ctx.config.model}. Use /config provider <provider> to set your provider.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ["OPENAI_API_KEY"] = key

# ragdaemon always needs an openai provider
providers = [OPEN_AI]
if user_provider is not None:
providers.append(user_provider)

for provider in providers:
try:
self.spice.load_provider(provider)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

match provider.name:
case "open_ai" | "openai":
env_variable = "OPENAI_API_KEY"
case "anthropic":
env_variable = "ANTHROPIC_API_KEY"
case "azure":
if os.getenv("AZURE_OPENAI_ENDPOINT") is None:
ctx.stream.send(
f"No Azure OpenAI endpoint detected. To avoid entering your endpoint on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root and set AZURE_OPENAI_ENDPOINT.",
style="warning",
)
ctx.stream.send("Enter your endpoint:", style="info")
endpoint = (await collect_user_input(log_input=False)).data
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
if os.getenv("AZURE_OPENAI_KEY") is not None:
return
env_variable = "AZURE_OPENAI_KEY"
case _:
raise MentatError(
f"No api key detected for provider {provider.name}. Create a .env file in ~/.mentat/.env or in your workspace root with your api key"
)

ctx.stream.send(
f"No {provider.name} api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ[env_variable] = key

@overload
async def call_llm_api(
@@ -191,7 +249,7 @@ async def call_llm_api(
config = session_context.config

# Confirm that model has enough tokens remaining
tokens = self.spice.count_prompt_tokens(messages, model)
tokens = self.spice.count_prompt_tokens(messages, model, provider)
raise_if_context_exceeds_max(tokens)

with sentry_sdk.start_span(description="LLM Call") as span:
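For reference, the new API-key prompting in initialize_client maps each provider name to the environment variable it fills. Condensed into a standalone helper (the helper name is illustrative; the mapping is copied from the match statement above, and Azure's additional AZURE_OPENAI_ENDPOINT requirement is handled separately in the real code):

```python
def api_key_env_variable(provider_name: str) -> str:
    """Map a spice provider name to the environment variable Mentat fills with its API key."""
    match provider_name:
        case "open_ai" | "openai":
            return "OPENAI_API_KEY"
        case "anthropic":
            return "ANTHROPIC_API_KEY"
        case "azure":
            # Azure also needs AZURE_OPENAI_ENDPOINT set before this key is usable.
            return "AZURE_OPENAI_KEY"
        case _:
            raise ValueError(f"No known api key environment variable for provider {provider_name}")
```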
2 changes: 1 addition & 1 deletion mentat/revisor/revisor.py
@@ -59,7 +59,7 @@ async def revise_edit(file_edit: FileEdit):
ChatCompletionSystemMessageParam(content=f"Diff:\n{diff}", role="system"),
]
code_message = await ctx.code_context.get_code_message(
ctx.llm_api_handler.spice.count_prompt_tokens(messages, ctx.config.model)
ctx.llm_api_handler.spice.count_prompt_tokens(messages, ctx.config.model, ctx.config.provider)
)
messages.insert(1, ChatCompletionSystemMessageParam(content=code_message, role="system"))

135 changes: 64 additions & 71 deletions mentat/session.py
@@ -9,12 +9,7 @@

import attr
import sentry_sdk
from openai import (
APITimeoutError,
BadRequestError,
PermissionDeniedError,
RateLimitError,
)
from spice.errors import APIConnectionError, APIError, AuthenticationError

from mentat.agent_handler import AgentHandler
from mentat.auto_completer import AutoCompleter
@@ -162,76 +157,74 @@ async def _main(self):
code_file_manager = session_context.code_file_manager
agent_handler = session_context.agent_handler

await session_context.llm_api_handler.initialize_client()

await code_context.refresh_daemon()

check_model()
try:
await session_context.llm_api_handler.initialize_client()
await code_context.refresh_daemon()

check_model()

need_user_request = True
while True:
try:
await code_context.refresh_context_display()
if need_user_request:
# Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
# edits made between user input to be collected together.
if agent_handler.agent_enabled:
code_file_manager.history.push_edits()
stream.send(
"Use /undo to undo all changes from agent mode since last input.",
style="success",
)
message = await collect_input_with_commands()
if message.data.strip() == "":
continue
conversation.add_user_message(message.data)

parsed_llm_response = await conversation.get_model_response()
file_edits = [file_edit for file_edit in parsed_llm_response.file_edits if file_edit.is_valid()]
for file_edit in file_edits:
file_edit.resolve_conflicts()
if file_edits:
if session_context.config.revisor:
await revise_edits(file_edits)

if session_context.config.sampler:
session_context.sampler.set_active_diff()

self.send_file_edits(file_edits)
if self.apply_edits:
if not agent_handler.agent_enabled:
file_edits, need_user_request = await get_user_feedback_on_edits(file_edits)
applied_edits = await code_file_manager.write_changes_to_files(file_edits)
stream.send(
("Changes applied." if applied_edits else "No changes applied."),
style="input",
)
else:
need_user_request = True

need_user_request = True
while True:
await code_context.refresh_context_display()
try:
if need_user_request:
# Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
# edits made between user input to be collected together.
if agent_handler.agent_enabled:
code_file_manager.history.push_edits()
stream.send(
"Use /undo to undo all changes from agent mode since last input.",
style="success",
)
message = await collect_input_with_commands()
if message.data.strip() == "":
continue
conversation.add_user_message(message.data)

parsed_llm_response = await conversation.get_model_response()
file_edits = [file_edit for file_edit in parsed_llm_response.file_edits if file_edit.is_valid()]
for file_edit in file_edits:
file_edit.resolve_conflicts()
if file_edits:
if session_context.config.revisor:
await revise_edits(file_edits)

if session_context.config.sampler:
session_context.sampler.set_active_diff()

self.send_file_edits(file_edits)
if self.apply_edits:
if not agent_handler.agent_enabled:
file_edits, need_user_request = await get_user_feedback_on_edits(file_edits)
applied_edits = await code_file_manager.write_changes_to_files(file_edits)
stream.send(
("Changes applied." if applied_edits else "No changes applied."),
style="input",
)
if agent_handler.agent_enabled:
if parsed_llm_response.interrupted:
need_user_request = True
else:
need_user_request = await agent_handler.add_agent_context()
else:
need_user_request = True
stream.send(bool(file_edits), channel="edits_complete")

if agent_handler.agent_enabled:
if parsed_llm_response.interrupted:
need_user_request = True
else:
need_user_request = await agent_handler.add_agent_context()
else:
except ReturnToUser:
stream.send(None, channel="loading", terminate=True)
need_user_request = True
stream.send(bool(file_edits), channel="edits_complete")
except SessionExit:
stream.send(None, channel="client_exit")
break
except ReturnToUser:
stream.send(None, channel="loading", terminate=True)
need_user_request = True
continue
except (
APITimeoutError,
RateLimitError,
BadRequestError,
PermissionDeniedError,
) as e:
stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
break
continue
except SessionExit:
stream.send(None, channel="client_exit")
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: Check your internet connection and try again.")
except APIError as e:
stream.send(f"Error accessing OpenAI API: {e}", style="error")

async def listen_for_session_exit(self):
await self.stream.recv(channel="session_exit")
