Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format image provider #66

Merged
merged 15 commits into from
May 1, 2023
Prev Previous commit
Next Next commit
Promotes VALID_TASKS out of class
  • Loading branch information
JasonWeill committed Apr 28, 2023
commit 754072d09d2cf2ff026cbaa950556b213e7d9f5c
10 changes: 5 additions & 5 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class CohereProvider(BaseProvider, Cohere):
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")

HUGGINGFACE_HUB_VALID_TASKS = ("text2text-generation", "text-generation", "text-to-image")

class HfHubProvider(BaseProvider, HuggingFaceHub):
id = "huggingface_hub"
name = "HuggingFace Hub"
Expand All @@ -152,18 +154,16 @@ def validate_environment(cls, values: Dict) -> Dict:
try:
from huggingface_hub.inference_api import InferenceApi

VALID_TASKS = ("text2text-generation", "text-generation", "text-to-image")

repo_id = values["repo_id"]
client = InferenceApi(
repo_id=repo_id,
token=huggingfacehub_api_token,
task=values.get("task"),
)
if client.task not in VALID_TASKS:
if client.task not in HUGGINGFACE_HUB_VALID_TASKS:
raise ValueError(
f"Got invalid task {client.task}, "
f"currently only {VALID_TASKS} are supported"
f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
)
values["client"] = client
except ImportError:
Expand Down Expand Up @@ -219,7 +219,7 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
else:
raise ValueError(
f"Got invalid task {self.client.task}, "
f"currently only {VALID_TASKS} are supported"
f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
Expand Down