Skip to content

Commit

Permalink
Format image provider (jupyterlab#66)
Browse files Browse the repository at this point in the history
* Adds image format

* WIP: New provider for huggingface_image

* Adds image provider to __init__

* Improve magics.

* Change raw to text.
* Add provider/model metadata to output.

* Fixing metadata handling.

* Updates user docs

* WIP: Prints debug output on initial load

* Removes debug info from load

* Removes separate provider for images

* Outputs image in Base64 format

* Removes HF image provider, updates docs

* Removes limitation of only supporting JPEG format

* Promotes VALID_TASKS out of class

* Updates error handling, image generation example notebook

* Updates dependencies per @dlqqq

---------

Co-authored-by: Brian E. Granger <brgrange@amazon.com>
Co-authored-by: Brian E. Granger <ellisonbg@gmail.com>
  • Loading branch information
3 people committed May 1, 2023
1 parent 469a480 commit 2c2c599
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 6 deletions.
5 changes: 4 additions & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ Jupyter AI supports the following model providers:
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
| HuggingFace Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets` |
| HuggingFace Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `openai` |
| SageMaker Endpoints | `sagemaker-endpoint` | N/A | `boto3` |

You need the `pillow` Python package to use HuggingFace Hub's text-to-image models.

To use SageMaker's models, you will need to authenticate via
[boto3](https://github.com/boto/boto3).

Expand Down Expand Up @@ -294,6 +296,7 @@ an `%%ai` command will be formatted as markdown by default. You can override thi
using the `-f` or `--format` argument to your magic command. Valid formats include:

- `code`
- `image` (for HuggingFace Hub's text-to-image models only)
- `markdown`
- `math`
- `html`
Expand Down
144 changes: 144 additions & 0 deletions examples/images.ipynb

Large diffs are not rendered by default.

24 changes: 21 additions & 3 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import base64
import json
import os
import re
import traceback
import warnings
from typing import Optional

from importlib_metadata import entry_points
from IPython import get_ipython
from IPython.core.magic import Magics, magics_class, line_cell_magic
from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring
from IPython.display import HTML, Markdown, Math, JSON
from IPython.display import HTML, Image, JSON, Markdown, Math

from .providers import BaseProvider

Expand All @@ -34,7 +36,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
}
)

class TextWithMetadata(object):
class TextWithMetadata:

def __init__(self, text, metadata):
self.text = text
Expand All @@ -43,9 +45,22 @@ def __init__(self, text, metadata):
def _repr_mimebundle_(self, include=None, exclude=None):
return ({'text/plain': self.text}, self.metadata)


class Base64Image:
    """Wrap a ``"<mime type>;base64,<data>"`` string so IPython renders it
    as an inline image, with provider metadata attached to the bundle.
    """

    def __init__(self, mimeData, metadata):
        # Split on the FIRST comma only: base64 payloads contain no commas,
        # but the header/data boundary is the only structural one.
        # (Fixes the stray trailing semicolon and unbounded split.)
        mime_data_parts = mimeData.split(',', 1)
        self.data = base64.b64decode(mime_data_parts[1])
        # Drop the ";base64" encoding suffix to leave the bare MIME type.
        self.mimeType = mime_data_parts[0].removesuffix(';base64')
        self.metadata = metadata

    def _repr_mimebundle_(self, include=None, exclude=None):
        """Return the (data, metadata) MIME bundle IPython displays."""
        return ({self.mimeType: self.data}, self.metadata)


DISPLAYS_BY_FORMAT = {
"code": None,
"html": HTML,
"image": Base64Image,
"markdown": Markdown,
"math": Math,
"md": Markdown,
Expand All @@ -60,6 +75,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
PROMPT_TEMPLATES_BY_FORMAT = {
"code": '{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.',
"html": '{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.',
"image": '{prompt}\n\nProduce output as an image only, with no text before or after it.',
"markdown": MARKDOWN_PROMPT_TEMPLATE,
"md": MARKDOWN_PROMPT_TEMPLATE,
"math": '{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.',
Expand Down Expand Up @@ -101,6 +117,8 @@ def __init__(self, shell):
try:
Provider = model_provider_ep.load()
except:
print(f"Unable to load entry point {model_provider_ep.name}");
traceback.print_exc()
continue
self.providers[Provider.id] = Provider

Expand Down Expand Up @@ -286,7 +304,7 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
optionally prefixed with the ID of the model provider, delimited
by a colon.""")
@argument('-f', '--format',
choices=["code", "markdown", "html", "json", "math", "md", "text"],
choices=["code", "html", "image", "json", "markdown", "math", "md", "text"],
nargs="?",
default="markdown",
help="""IPython display to use when rendering output. [default="markdown"]""")
Expand Down
96 changes: 94 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import ClassVar, List, Union, Literal, Optional
from typing import ClassVar, Dict, List, Union, Literal, Optional

import base64

import io

from langchain.schema import BaseLanguageModel as BaseLangchainProvider
from langchain.llms import (
Expand All @@ -10,8 +14,9 @@
OpenAIChat,
SagemakerEndpoint
)
from langchain.utils import get_from_dict_or_env

from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.chat_models import ChatOpenAI

class EnvAuthStrategy(BaseModel):
Expand Down Expand Up @@ -126,6 +131,8 @@ class CohereProvider(BaseProvider, Cohere):
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")

HUGGINGFACE_HUB_VALID_TASKS = ("text2text-generation", "text-generation", "text-to-image")

class HfHubProvider(BaseProvider, HuggingFaceHub):
id = "huggingface_hub"
name = "HuggingFace Hub"
Expand All @@ -137,6 +144,91 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
pypi_package_deps = ["huggingface_hub", "ipywidgets"]
auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")

# Replaces HuggingFaceHub.validate_environment so that "text-to-image"
# models are accepted alongside the parent's text-only tasks.
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
    """Validate that api key and python package exists in environment."""
    token = get_from_dict_or_env(
        values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
    )
    try:
        from huggingface_hub.inference_api import InferenceApi

        inference_client = InferenceApi(
            repo_id=values["repo_id"],
            token=token,
            task=values.get("task"),
        )
        # Reject any task outside the widened whitelist; the ValueError
        # deliberately propagates past the ImportError handler below.
        if inference_client.task not in HUGGINGFACE_HUB_VALID_TASKS:
            raise ValueError(
                f"Got invalid task {inference_client.task}, "
                f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
            )
        values["client"] = inference_client
    except ImportError:
        raise ValueError(
            "Could not import huggingface_hub python package. "
            "Please install it with `pip install huggingface_hub`."
        )
    return values

# Handle image outputs
@staticmethod
def _encode_image_response(image) -> str:
    """Encode a generated image as a ``"<mime type>;base64,<data>"`` string.

    ``image`` is presumed to be a PIL ImageFile (as returned by the
    inference API for text-to-image tasks) — TODO confirm against
    huggingface_hub's InferenceApi return type.
    """
    # Map PIL's format name to a MIME type; anything unrecognized is
    # rejected so the caller gets a clear error instead of a bogus URI.
    mime_types = {"JPEG": "image/jpeg", "PNG": "image/png", "GIF": "image/gif"}
    image_format = image.format
    try:
        mime_type = mime_types[image_format]
    except KeyError:
        raise ValueError(f"Unrecognized image format {image_format}") from None

    buffer = io.BytesIO()
    image.save(buffer, format=image_format)
    # Encode image data to Base64 bytes, then decode bytes to str
    return mime_type + ';base64,' + base64.b64encode(buffer.getvalue()).decode()

def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    """Call out to HuggingFace Hub's inference endpoint.

    Args:
        prompt: The prompt to pass into the model.
        stop: Optional list of stop words to use when generating.

    Returns:
        The string generated by the model, or — for "text-to-image"
        models — a ``"<mime type>;base64,<data>"`` string.

    Raises:
        ValueError: If the inference API reports an error, the image
            format is unrecognized, or the task is unsupported.

    Example:
        .. code-block:: python

            response = hf("Tell me a joke.")
    """
    _model_kwargs = self.model_kwargs or {}
    response = self.client(inputs=prompt, params=_model_kwargs)

    # isinstance (not `type(...) is dict`) also accepts dict subclasses.
    if isinstance(response, dict) and "error" in response:
        raise ValueError(f"Error raised by inference API: {response['error']}")

    # Custom code for responding to image generation responses
    if self.client.task == "text-to-image":
        return self._encode_image_response(response)

    if self.client.task == "text-generation":
        # Text generation return includes the starter text.
        text = response[0]["generated_text"][len(prompt):]
    elif self.client.task == "text2text-generation":
        text = response[0]["generated_text"]
    else:
        raise ValueError(
            f"Got invalid task {self.client.task}, "
            f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
        )
    if stop is not None:
        # This is a bit hacky, but I can't figure out a better way to enforce
        # stop tokens when making calls to huggingface_hub.
        text = enforce_stop_tokens(text, stop)
    return text

class OpenAIProvider(BaseProvider, OpenAI):
id = "openai"
name = "OpenAI"
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ all = [
"cohere",
"huggingface_hub",
"ipywidgets",
"pillow",
"openai",
"boto3"
]
Expand Down

0 comments on commit 2c2c599

Please sign in to comment.