Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Format image provider #66

Merged
merged 15 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ Jupyter AI supports the following model providers:
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
| HuggingFace Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets` |
| HuggingFace Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `openai` |
| SageMaker Endpoints | `sagemaker-endpoint` | N/A | `boto3` |

You need the `pillow` Python package to use HuggingFace Hub's text-to-image models.

To use SageMaker's models, you will need to authenticate via
[boto3](https://github.com/boto/boto3).

Expand Down Expand Up @@ -294,6 +296,7 @@ an `%%ai` command will be formatted as markdown by default. You can override thi
using the `-f` or `--format` argument to your magic command. Valid formats include:

- `code`
- `image` (for HuggingFace Hub's text-to-image models only)
- `markdown`
- `math`
- `html`
Expand Down
144 changes: 144 additions & 0 deletions examples/images.ipynb

Large diffs are not rendered by default.

24 changes: 21 additions & 3 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import base64
import json
import os
import re
import traceback
import warnings
from typing import Optional

from importlib_metadata import entry_points
from IPython import get_ipython
from IPython.core.magic import Magics, magics_class, line_cell_magic
from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring
from IPython.display import HTML, Markdown, Math, JSON
from IPython.display import HTML, Image, JSON, Markdown, Math

from .providers import BaseProvider

Expand All @@ -34,7 +36,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
}
)

class TextWithMetadata(object):
class TextWithMetadata:

def __init__(self, text, metadata):
self.text = text
Expand All @@ -43,9 +45,22 @@ def __init__(self, text, metadata):
def _repr_mimebundle_(self, include=None, exclude=None):
return ({'text/plain': self.text}, self.metadata)


class Base64Image:
    """Wrap a base64 image payload so IPython can render it natively.

    Parameters
    ----------
    mimeData:
        A string of the form ``"<mime-type>;base64,<base64-data>"`` (the
        payload portion of an RFC 2397 data URI), e.g.
        ``"image/png;base64,iVBOR..."``.
    metadata:
        Display metadata passed through unchanged to IPython's
        mimebundle protocol.
    """

    def __init__(self, mimeData, metadata):
        # Split only at the first comma; the base64 alphabet never
        # contains a comma, so everything after it is the encoded data.
        header, _, encoded = mimeData.partition(',')
        self.data = base64.b64decode(encoded)
        self.mimeType = header.removesuffix(';base64')
        self.metadata = metadata

    def _repr_mimebundle_(self, include=None, exclude=None):
        # IPython picks the richest mimetype it supports from this bundle.
        return ({self.mimeType: self.data}, self.metadata)


DISPLAYS_BY_FORMAT = {
"code": None,
"html": HTML,
"image": Base64Image,
"markdown": Markdown,
"math": Math,
"md": Markdown,
Expand All @@ -60,6 +75,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
PROMPT_TEMPLATES_BY_FORMAT = {
"code": '{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.',
"html": '{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.',
"image": '{prompt}\n\nProduce output as an image only, with no text before or after it.',
"markdown": MARKDOWN_PROMPT_TEMPLATE,
"md": MARKDOWN_PROMPT_TEMPLATE,
"math": '{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.',
Expand Down Expand Up @@ -101,6 +117,8 @@ def __init__(self, shell):
try:
Provider = model_provider_ep.load()
except:
print(f"Unable to load entry point {model_provider_ep.name}");
traceback.print_exc()
continue
self.providers[Provider.id] = Provider

Expand Down Expand Up @@ -286,7 +304,7 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
optionally prefixed with the ID of the model provider, delimited
by a colon.""")
@argument('-f', '--format',
choices=["code", "markdown", "html", "json", "math", "md", "text"],
choices=["code", "html", "image", "json", "markdown", "math", "md", "text"],
nargs="?",
default="markdown",
help="""IPython display to use when rendering output. [default="markdown"]""")
Expand Down
96 changes: 94 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import ClassVar, List, Union, Literal, Optional
from typing import ClassVar, Dict, List, Union, Literal, Optional

import base64

import io

from langchain.schema import BaseLanguageModel as BaseLangchainProvider
from langchain.llms import (
Expand All @@ -10,8 +14,9 @@
OpenAIChat,
SagemakerEndpoint
)
from langchain.utils import get_from_dict_or_env

from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.chat_models import ChatOpenAI

class EnvAuthStrategy(BaseModel):
Expand Down Expand Up @@ -126,6 +131,8 @@ class CohereProvider(BaseProvider, Cohere):
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")

HUGGINGFACE_HUB_VALID_TASKS = ("text2text-generation", "text-generation", "text-to-image")

class HfHubProvider(BaseProvider, HuggingFaceHub):
id = "huggingface_hub"
name = "HuggingFace Hub"
Expand All @@ -137,6 +144,91 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
pypi_package_deps = ["huggingface_hub", "ipywidgets"]
auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")

# Override the parent's validate_environment with a custom list of valid tasks
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub.inference_api import InferenceApi

repo_id = values["repo_id"]
client = InferenceApi(
repo_id=repo_id,
token=huggingfacehub_api_token,
task=values.get("task"),
)
if client.task not in HUGGINGFACE_HUB_VALID_TASKS:
raise ValueError(
f"Got invalid task {client.task}, "
f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
)
values["client"] = client
except ImportError:
raise ValueError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
return values

# Handle image outputs
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
JasonWeill marked this conversation as resolved.
Show resolved Hide resolved
"""Call out to HuggingFace Hub's inference endpoint.

Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.

Returns:
The string or image generated by the model.

Example:
.. code-block:: python

response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
response = self.client(inputs=prompt, params=_model_kwargs)

if type(response) is dict and "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")

# Custom code for responding to image generation responses
if self.client.task == "text-to-image":
imageFormat = response.format # Presume it's a PIL ImageFile
mimeType = ''
if (imageFormat == 'JPEG'):
mimeType = 'image/jpeg'
elif (imageFormat == 'PNG'):
mimeType = 'image/png'
elif (imageFormat == 'GIF'):
mimeType = 'image/gif'
else:
raise ValueError(f"Unrecognized image format {imageFormat}")

buffer = io.BytesIO()
response.save(buffer, format=imageFormat)
# Encode image data to Base64 bytes, then decode bytes to str
return (mimeType + ';base64,' + base64.b64encode(buffer.getvalue()).decode())

if self.client.task == "text-generation":
# Text generation return includes the starter text.
text = response[0]["generated_text"][len(prompt) :]
elif self.client.task == "text2text-generation":
text = response[0]["generated_text"]
else:
raise ValueError(
f"Got invalid task {self.client.task}, "
f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text

class OpenAIProvider(BaseProvider, OpenAI):
id = "openai"
name = "OpenAI"
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ all = [
"cohere",
"huggingface_hub",
"ipywidgets",
"pillow",
"openai",
"boto3"
]
Expand Down