Skip to content

Commit

Permalink
Outputs image in Base64 format
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonWeill committed Apr 25, 2023
1 parent ed3891c commit 67ece14
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
23 changes: 20 additions & 3 deletions examples/images.ipynb

Large diffs are not rendered by default.

16 changes: 14 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import json
import os
import traceback
Expand Down Expand Up @@ -34,7 +35,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
}
)

class TextWithMetadata(object):
class TextWithMetadata:

def __init__(self, text, metadata):
self.text = text
Expand All @@ -43,10 +44,21 @@ def __init__(self, text, metadata):
def _repr_mimebundle_(self, include=None, exclude=None):
return ({'text/plain': self.text}, self.metadata)


class Base64Image:
    """Wraps a Base64-encoded image so IPython can render it via a MIME bundle.

    The Base64 payload is decoded once at construction; the raw image bytes
    are emitted under the given MIME type by ``_repr_mimebundle_``.
    """

    def __init__(self, base64bytes, metadata, mimeType='image/jpeg'):
        # base64bytes: Base64-encoded image data (str or bytes — b64decode
        # accepts either). Removed the stray trailing semicolon.
        self.data = base64.b64decode(base64bytes)
        self.mimeType = mimeType
        self.metadata = metadata

    def _repr_mimebundle_(self, include=None, exclude=None):
        # (data, metadata) pair per IPython's display protocol.
        return ({self.mimeType: self.data}, self.metadata)


DISPLAYS_BY_FORMAT = {
"code": None,
"html": HTML,
"image": Image,
"image": Base64Image,
"markdown": Markdown,
"math": Math,
"md": Markdown,
Expand Down
46 changes: 46 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import ClassVar, Dict, List, Union, Literal, Optional

import base64

import io

from langchain.schema import BaseLanguageModel as BaseLangchainProvider
from langchain.llms import (
AI21,
Expand Down Expand Up @@ -170,6 +174,48 @@ def validate_environment(cls, values: Dict) -> Dict:
"Please install it with `pip install huggingface_hub`."
)
return values

# Handle image outputs
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string or image generated by the model.
Example:
.. code-block:: python
response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
response = self.client(inputs=prompt, params=_model_kwargs)
# Custom code for responding to image generation responses
if self.client.task == "text-to-image":
buffer = io.BytesIO()
response.save(buffer, format='JPEG', quality=75)
return base64.b64encode(buffer.getvalue())

if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
if self.client.task == "text-generation":
# Text generation return includes the starter text.
text = response[0]["generated_text"][len(prompt) :]
elif self.client.task == "text2text-generation":
text = response[0]["generated_text"]
else:
raise ValueError(
f"Got invalid task {self.client.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text

class OpenAIProvider(BaseProvider, OpenAI):
id = "openai"
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ dependencies = [
"ipython",
"pydantic",
"importlib_metadata~=5.2.0",
"langchain~=0.0.144"
"langchain~=0.0.144",
"Pillow"
]

[project.optional-dependencies]
Expand Down

0 comments on commit 67ece14

Please sign in to comment.