diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
index 5d8f3484c..c52c308db 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -137,7 +137,7 @@ async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
         """
         Handles an exception raised in either `handle_request()` or
         `handle_stream_request()`. This base class provides a default
-        implementation, which may be overriden by subclasses.
+        implementation, which may be overridden by subclasses.
         """
         error = CompletionError(
             type=e.__class__.__name__,
@@ -162,8 +162,6 @@ async def _handle_request(self, request: InlineCompletionRequest):
     async def _handle_stream_request(self, request: InlineCompletionRequest):
         """Private wrapper around `self.handle_stream_request()`."""
         start = time.time()
-        await self._handle_stream_request(request)
-        async for chunk in self.stream(request):
-            self.write_message(chunk.dict())
+        await self.handle_stream_request(request)
         latency_ms = round((time.time() - start) * 1000)
         self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
index 1bf0921e5..687e41fed 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -73,9 +73,7 @@ def create_llm_chain(
         self.llm = llm
         self.llm_chain = prompt_template | llm | StrOutputParser()
 
-    async def handle_request(
-        self, request: InlineCompletionRequest
-    ) -> InlineCompletionReply:
+    async def handle_request(self, request: InlineCompletionRequest) -> None:
         """Handles an inline completion request without streaming."""
         self.get_llm_chain()
         model_arguments = self._template_inputs_from_request(request)
@@ -111,7 +109,7 @@ def _write_incomplete_reply(self, request: InlineCompletionRequest):
 
     async def handle_stream_request(self, request: InlineCompletionRequest):
         # first, send empty initial reply.
-        self._write_incomplete_reply()
+        self._write_incomplete_reply(request)
         # then, generate and stream LLM output over this connection.
         self.get_llm_chain()
diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/__init__.py b/packages/jupyter-ai/jupyter_ai/tests/completions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py
new file mode 100644
index 000000000..1b950af74
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py
@@ -0,0 +1,116 @@
+import json
+from types import SimpleNamespace
+
+from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler
+from jupyter_ai.completions.models import InlineCompletionRequest
+from jupyter_ai_magics import BaseProvider
+from langchain_community.llms import FakeListLLM
+from pytest import fixture
+from tornado.httputil import HTTPServerRequest
+from tornado.web import Application
+
+
+class MockProvider(BaseProvider, FakeListLLM):
+    id = "my_provider"
+    name = "My Provider"
+    model_id_key = "model"
+    models = ["model"]
+
+    def __init__(self, **kwargs):
+        kwargs["responses"] = ["Test response"]
+        super().__init__(**kwargs)
+
+
+class MockCompletionHandler(DefaultInlineCompletionHandler):
+    def __init__(self):
+        self.request = HTTPServerRequest()
+        self.application = Application()
+        self.messages = []
+        self.tasks = []
+        self.settings["jai_config_manager"] = SimpleNamespace(
+            lm_provider=MockProvider, lm_provider_params={"model_id": "model"}
+        )
+        self.settings["jai_event_loop"] = SimpleNamespace(
+            create_task=lambda x: self.tasks.append(x)
+        )
+        self.settings["model_parameters"] = {}
+        self.llm_params = {}
+        self.create_llm_chain(MockProvider, {"model_id": "model"})
+
+    def write_message(self, message: str) -> None:  # type: ignore
+        self.messages.append(message)
+
+    async def handle_exc(self, e: Exception, _request: InlineCompletionRequest):
+        # raise all exceptions during testing rather than replying with an error message
+        raise e
+
+
+@fixture
+def inline_handler() -> MockCompletionHandler:
+    return MockCompletionHandler()
+
+
+async def test_on_message(inline_handler):
+    request = InlineCompletionRequest(
+        number=1, prefix="", suffix="", mime="", stream=False
+    )
+    # Test end to end, without checking details of the replies,
+    # which are tested in appropriate method unit tests.
+    await inline_handler.on_message(json.dumps(dict(request)))
+    assert len(inline_handler.tasks) == 1
+    await inline_handler.tasks[0]
+    assert len(inline_handler.messages) == 1
+
+
+async def test_on_message_stream(inline_handler):
+    stream_request = InlineCompletionRequest(
+        number=1, prefix="", suffix="", mime="", stream=True
+    )
+    # Test end to end, without checking details of the replies,
+    # which are tested in appropriate method unit tests.
+    await inline_handler.on_message(json.dumps(dict(stream_request)))
+    assert len(inline_handler.tasks) == 1
+    await inline_handler.tasks[0]
+    assert len(inline_handler.messages) == 3
+
+
+async def test_handle_request(inline_handler):
+    dummy_request = InlineCompletionRequest(
+        number=1, prefix="", suffix="", mime="", stream=False
+    )
+    await inline_handler.handle_request(dummy_request)
+    # should write a single reply
+    assert len(inline_handler.messages) == 1
+    # reply should contain a single suggestion
+    suggestions = inline_handler.messages[0].list.items
+    assert len(suggestions) == 1
+    # the suggestion should include insert text from LLM
+    assert suggestions[0].insertText == "Test response"
+
+
+async def test_handle_stream_request(inline_handler):
+    inline_handler.llm_chain = FakeListLLM(responses=["test"])
+    dummy_request = InlineCompletionRequest(
+        number=1, prefix="", suffix="", mime="", stream=True
+    )
+    await inline_handler.handle_stream_request(dummy_request)
+
+    # should write three replies
+    assert len(inline_handler.messages) == 3
+
+    # first reply should be empty to start the stream
+    first = inline_handler.messages[0].list.items[0]
+    assert first.insertText == ""
+    assert first.isIncomplete == True
+
+    # second reply should be a chunk containing the token
+    second = inline_handler.messages[1]
+    assert second.type == "stream"
+    assert second.response.insertText == "Test response"
+    assert second.done == False
+
+    # third reply should be a closing chunk
+    third = inline_handler.messages[2]
+    assert third.type == "stream"
+    assert third.response.insertText == "Test response"
+    assert third.done == True