Document how to create completions using full notebook content #777

Merged
88 changes: 88 additions & 0 deletions docs/source/developers/index.md
@@ -228,6 +228,94 @@ class MyCompletionProvider(BaseProvider, FakeListLLM):
)
```


#### Using the full notebook content for completions

The `InlineCompletionRequest` contains the `path` of the current document (file or notebook).
Inline completion providers can use this path to read the content of the notebook from disk;
however, that content may be outdated if the user has not saved the notebook recently.

The accuracy of the suggestions can be slightly improved by combining the potentially outdated content of previous/following cells
with the `prefix` and `suffix` which describe the up-to-date state of the current cell (identified by `cell_id`).
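
As a rough illustration of this disk-based approach, the sketch below surrounds the live `prefix` and `suffix` with the cells of a saved notebook. The helper name `context_from_saved_notebook` is hypothetical, and the sketch assumes a notebook format in which every cell carries an `id`. The notebook is passed as an already-parsed dict (for example the result of `json.load`), so the example has no file dependencies.

```python
def context_from_saved_notebook(notebook, cell_id, prefix, suffix, cell_type="code"):
    """Surround the live prefix/suffix with (possibly stale) saved cells."""
    before, after = [], [suffix]
    seen_current = False
    for cell in notebook.get("cells", []):
        if cell.get("id") == cell_id:
            # The saved copy of the current cell may be stale, so skip it
            # and rely on the up-to-date prefix/suffix instead.
            seen_current = True
            continue
        if cell.get("cell_type") != cell_type:
            continue
        source = cell.get("source", "")
        # Saved notebooks store source as a string or as a list of lines.
        if isinstance(source, list):
            source = "".join(source)
        (after if seen_current else before).append(source)
    before.append(prefix)
    return "\n\n".join(before), "\n\n".join(after)


# Hypothetical saved notebook; the user is currently typing in cell "c".
saved = {
    "cells": [
        {"id": "a", "cell_type": "code", "source": ["import math\n", "x = 2"]},
        {"id": "b", "cell_type": "markdown", "source": "Some notes"},
        {"id": "c", "cell_type": "code", "source": "y = 0  # stale"},
        {"id": "d", "cell_type": "code", "source": "print(y)"},
    ]
}
full_prefix, full_suffix = context_from_saved_notebook(saved, "c", "y = math.", "")
```

Only cells of the requested type are included, and the stale saved copy of the current cell is never used.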

Still, reading the full notebook from disk may be slow for larger notebooks, which conflicts with the low-latency requirement of inline completion.

A better approach is to use the live copy of the notebook document that is persisted on the jupyter-server when *collaborative* document models are enabled.
Two packages need to be installed to access the collaborative models:
- `jupyter-server-ydoc` (>= 1.0) stores the collaborative models in the jupyter-server at runtime
- `jupyter-docprovider` (>= 1.0) reconfigures JupyterLab/Notebook to use the collaborative models

Both packages are installed automatically with `jupyter-collaboration` (v3.0 or newer); however, installing `jupyter-collaboration` is not required to take advantage of *collaborative* models.

The snippet below demonstrates how to retrieve the content of all cells of a given type from the in-memory copy of the collaborative model (without additional disk reads).

```python
from jupyter_ydoc import YNotebook


class MyCompletionProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["model_a"]

    def __init__(self, **kwargs):
        kwargs["responses"] = ["This fake response will not be used for completion"]
        super().__init__(**kwargs)

    async def _get_prefix_and_suffix(self, request: InlineCompletionRequest):
        prefix = request.prefix
        suffix = request.suffix.strip()

        server_ydoc = self.server_settings.get("jupyter_server_ydoc", None)
        if not server_ydoc:
            # fallback to prefix/suffix from single cell
            return prefix, suffix

        is_notebook = request.path.endswith("ipynb")
        document = await server_ydoc.get_document(
            path=request.path,
            content_type="notebook" if is_notebook else "file",
            file_format="json" if is_notebook else "text"
        )
        if not document or not isinstance(document, YNotebook):
            return prefix, suffix

        cell_type = "markdown" if request.language == "markdown" else "code"

        is_before_request_cell = True
        before = []
        after = [suffix]

        for cell in document.ycells:
            if is_before_request_cell and cell["id"] == request.cell_id:
                is_before_request_cell = False
                continue
            if cell["cell_type"] != cell_type:
                continue
            source = cell["source"].to_py()
            if is_before_request_cell:
                before.append(source)
            else:
                after.append(source)

        before.append(prefix)
        prefix = "\n\n".join(before)
        suffix = "\n\n".join(after)
        return prefix, suffix

    async def generate_inline_completions(self, request: InlineCompletionRequest):
        prefix, suffix = await self._get_prefix_and_suffix(request)

        return InlineCompletionReply(
            list=InlineCompletionList(items=[
                {"insertText": your_llm_function(prefix, suffix)}
            ]),
            reply_to=request.number,
        )
```


## Prompt templates

Each provider can define **prompt templates** for each supported format. A prompt
8 changes: 8 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -5,6 +5,7 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from types import MappingProxyType
from typing import (
    Any,
    AsyncIterator,
@@ -265,6 +266,13 @@ class Config:
    provider is selected.
    """

    server_settings: ClassVar[Optional[MappingProxyType[str, Any]]] = None
Review comment (Member):

Can you check if there's a way to raise an exception when this class attr is assigned more than once? This is not a PR blocker, but it would be a good sanity check to include for developers who accidentally assign to this class attr again in their own server extension.

Review comment (Member, Author):

I agree that this would be nice to have. It would be trivial for an instance attribute with a property descriptor. I think it is feasible for a class instance too by using a metaclass. Taking a look

Review comment (Member, Author):

I will merge this PR and open a separate follow-up PR to improve this.

    """Settings passed on from jupyter-ai package.

    The same server settings are shared between all providers.
    Providers are not allowed to mutate this dictionary.
    """

    @classmethod
    def chat_models(self):
        """Models which are suitable for chat."""
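
The review discussion on `server_settings` mentions a metaclass as one way to reject a second assignment to the class attribute. The block below is a minimal sketch of that idea only, not the follow-up implementation: `AssignOnceMeta` and `Provider` are hypothetical names, and a plain class stands in for `BaseProvider`. If the real class already uses a custom metaclass (the `class Config:` context suggests a pydantic-style model), the guard would presumably have to derive from that metaclass instead of `type`.

```python
class AssignOnceMeta(type):
    """Hypothetical metaclass: selected class attributes may be set only once."""

    _assign_once = frozenset({"server_settings"})

    def __setattr__(cls, name, value):
        # Class-level assignment (Provider.x = ...) is routed through here.
        if name in cls._assign_once and getattr(cls, name, None) is not None:
            raise AttributeError(f"{name!r} was already set on {cls.__name__}")
        super().__setattr__(name, value)


class Provider(metaclass=AssignOnceMeta):
    # Placeholder until the server extension assigns the real settings.
    server_settings = None


Provider.server_settings = {"example": True}  # first assignment succeeds
try:
    Provider.server_settings = {}  # second assignment is rejected
except AttributeError as error:
    second_error = str(error)
```

The default `None` in the class body does not trigger the guard, because names defined in a class body are placed in the class namespace without calling the metaclass `__setattr__`.
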
8 changes: 7 additions & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,11 +1,12 @@
import os
import re
import time
import types

from dask.distributed import Client as DaskClient
from importlib_metadata import entry_points
from jupyter_ai.chat_handlers.learn import Retriever
-from jupyter_ai_magics import JupyternautPersona
+from jupyter_ai_magics import BaseProvider, JupyternautPersona
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from tornado.web import StaticFileHandler
@@ -202,6 +203,11 @@ def initialize_settings(self):
            defaults=defaults,
        )

        # Expose a subset of settings as read-only to the providers
        BaseProvider.server_settings = types.MappingProxyType(
            self.serverapp.web_app.settings
        )

        self.log.info("Registered providers.")

        self.log.info(f"Registered {self.name} server extension")
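
For readers unfamiliar with `types.MappingProxyType`, the short sketch below shows the behaviour this change relies on: the proxy rejects writes but still reflects later changes made by the owner of the underlying dict. The `settings` dict here is a stand-in for `web_app.settings`, and its contents are made up for illustration.

```python
import types

# Stand-in for the server's settings dict (contents are illustrative).
settings = {"jupyter_server_ydoc": None}
read_only = types.MappingProxyType(settings)

try:
    read_only["new_key"] = 1  # proxies do not support item assignment
except TypeError as error:
    write_error = type(error).__name__

# The owner can still update the underlying dict; the proxy is a live view.
settings["jupyter_server_ydoc"] = "ydoc extension"
current_value = read_only["jupyter_server_ydoc"]
```

The proxy only protects the top-level mapping; mutable values stored inside it can still be modified through the proxy.
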