Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add provider checking based on model name and provider #571

Merged
merged 6 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
validate_and_format_path,
)
from mentat.interval import parse_intervals, split_intervals_from_path
from mentat.llm_api_handler import get_max_tokens
from mentat.llm_api_handler import api_guard, get_max_tokens
from mentat.session_context import SESSION_CONTEXT
from mentat.session_stream import SessionStream
from mentat.utils import get_relative_path, mentat_dir_path
Expand Down Expand Up @@ -59,6 +59,7 @@ def __init__(
self.include_files: Dict[Path, List[CodeFeature]] = {}
self.ignore_files: Set[Path] = set()

@api_guard
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
async def refresh_daemon(self):
"""Call before interacting with context to ensure daemon is up to date."""

Expand Down
78 changes: 60 additions & 18 deletions mentat/llm_api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from dotenv import load_dotenv
from openai.types.chat.completion_create_params import ResponseFormat
from spice import EmbeddingResponse, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse, TranscriptionResponse
from spice.errors import APIConnectionError, NoAPIKeyError
from spice.errors import APIConnectionError, AuthenticationError, InvalidProviderError, NoAPIKeyError
from spice.models import WHISPER_1
from spice.providers import OPEN_AI
from spice.spice import UnknownModelError, get_model_from_name
from spice.spice import UnknownModelError, get_model_from_name, get_provider_from_name

from mentat.errors import MentatError, ReturnToUser
from mentat.session_context import SESSION_CONTEXT
Expand Down Expand Up @@ -58,13 +57,20 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return await func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return async_wrapper # pyright: ignore[reportReturnType]
else:
Expand All @@ -73,13 +79,20 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return sync_wrapper

Expand Down Expand Up @@ -142,19 +155,48 @@ async def initialize_client(self):
if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
load_dotenv()

try:
self.spice.load_provider(OPEN_AI)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

ctx.stream.send(
"No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ["OPENAI_API_KEY"] = key
provider = get_model_from_name(ctx.config.model).provider
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
if ctx.config.provider is not None:
try:
provider = get_provider_from_name(ctx.config.provider)
except InvalidProviderError:
ctx.stream.send(
f"Unknown provider {ctx.config.provider}. Use /config provider <provider> to set your provider.",
style="warning",
)
elif provider is None:
ctx.stream.send(f"Unknown model {ctx.config.model}. Use /config provider <provider> to set your provider.")

if provider is not None:
try:
self.spice.load_provider(provider)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

match provider.name:
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
case "open_ai":
env_variable = "OPENAI_API_KEY"
case "anthropic":
env_variable = "ANTHROPIC_API_KEY"
case "azure":
if os.getenv("AZURE_OPENAI_ENDPOINT") is None:
raise MentatError(
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
f"No AZURE_OPENAI_ENDPOINT detected. Create a .env file in ~/.mentat/.env or in your workspace root with your Azure endpoint."
)
env_variable = "AZURE_OPENAI_KEY"
case _:
raise MentatError(
f"No api key detected for provider {provider.name}. Create a .env file in ~/.mentat/.env or in your workspace root with your api key"
)

ctx.stream.send(
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
f"No {provider.name} api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
os.environ[env_variable] = key

@overload
async def call_llm_api(
Expand Down
Loading