Skip to content

Commit

Permalink
Allow usage without NVIDIA partner package (#622)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Feb 6, 2024
1 parent bbc08cd commit 6aeb87f
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 28 deletions.
1 change: 0 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
BedrockChatProvider,
BedrockProvider,
ChatAnthropicProvider,
ChatNVIDIAProvider,
ChatOpenAIProvider,
CohereProvider,
GPT4AllProvider,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from jupyter_ai_magics.providers import BaseProvider, EnvAuthStrategy
from langchain_nvidia_ai_endpoints import ChatNVIDIA


class ChatNVIDIAProvider(BaseProvider, ChatNVIDIA):
id = "nvidia-chat"
name = "NVIDIA"
models = [
"playground_llama2_70b",
"playground_nemotron_steerlm_8b",
"playground_mistral_7b",
"playground_nv_llama2_rlhf_70b",
"playground_llama2_13b",
"playground_steerlm_llama_70b",
"playground_llama2_code_13b",
"playground_yi_34b",
"playground_mixtral_8x7b",
"playground_neva_22b",
"playground_llama2_code_34b",
]
model_id_key = "model"
auth_strategy = EnvAuthStrategy(name="NVIDIA_API_KEY")
pypi_package_deps = ["langchain_nvidia_ai_endpoints"]
21 changes: 0 additions & 21 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
OpenAI,
SagemakerEndpoint,
)
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# this is necessary because `langchain.pydantic_v1.main` does not include
# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
Expand Down Expand Up @@ -859,23 +858,3 @@ class QianfanProvider(BaseProvider, QianfanChatEndpoint):
model_id_key = "model_name"
pypi_package_deps = ["qianfan"]
auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])


class ChatNVIDIAProvider(BaseProvider, ChatNVIDIA):
id = "nvidia-chat"
name = "NVIDIA"
models = [
"playground_llama2_70b",
"playground_nemotron_steerlm_8b",
"playground_mistral_7b",
"playground_nv_llama2_rlhf_70b",
"playground_llama2_13b",
"playground_steerlm_llama_70b",
"playground_llama2_code_13b",
"playground_yi_34b",
"playground_mixtral_8x7b",
"playground_neva_22b",
"playground_llama2_code_34b",
]
model_id_key = "model"
auth_strategy = EnvAuthStrategy(name="NVIDIA_API_KEY")
16 changes: 11 additions & 5 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,22 @@ def get_lm_providers(
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
for model_provider_ep in model_provider_eps:
provider_ep_group = eps.select(group="jupyter_ai.model_providers")
for provider_ep in provider_ep_group:
try:
provider = model_provider_ep.load()
provider = provider_ep.load()
except ImportError as e:
log.warning(
f"Unable to load model provider `{provider_ep.name}`. Please install the `{e.name}` package."
)
continue
except Exception as e:
log.error(
f"Unable to load model provider class from entry point `{model_provider_ep.name}`: %s.",
e,
f"Unable to load model provider `{provider_ep.name}`. Printing full exception below."
)
log.exception(e)
continue

if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider"
qianfan = "jupyter_ai_magics:QianfanProvider"
nvidia-chat = "jupyter_ai_magics:ChatNVIDIAProvider"
nvidia-chat = "jupyter_ai_magics.partner_providers.nvidia:ChatNVIDIAProvider"

[project.entry-points."jupyter_ai.embeddings_model_providers"]
bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
Expand Down

0 comments on commit 6aeb87f

Please sign in to comment.