Basic collaborative chat (jupyterlab#58)
* add .jupyter_ystore.db to .gitignore

* connect chat UI to use websockets handlers

* remove old Chat REST API

* remove console logs

* do not use identity provider username for client ID

* remove old messages property
dlqqq committed Apr 13, 2023
1 parent 67fde68 commit 3be6ba0
Showing 10 changed files with 352 additions and 239 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -127,4 +127,6 @@ playground/
# reserve path for a dev script
dev.sh

.vscode
.vscode

.jupyter_ystore.db
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,7 +1,7 @@
import queue
from jupyter_server.extension.application import ExtensionApp
from langchain import ConversationChain
from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler, ChatAPIHandler
from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler
from importlib_metadata import entry_points
import inspect
from .engine import BaseModelEngine
@@ -20,7 +20,6 @@ class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
("api/ai/prompt", PromptAPIHandler),
(r"api/ai/chat/?", ChatAPIHandler),
(r"api/ai/tasks/?", TaskAPIHandler),
(r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler),
(r"api/ai/chats/?", ChatHandler),
@@ -113,5 +112,7 @@ def initialize_settings(self):

# Store chat clients in a dictionary
self.settings["chat_clients"] = {}
self.settings["chat_handlers"] = {}


# store chat messages in memory for now
self.settings["chat_history"] = []
152 changes: 88 additions & 64 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -1,8 +1,9 @@
from dataclasses import asdict
import json
from typing import Optional

from typing import Dict, List
import tornado
import uuid

from tornado.web import HTTPError
from pydantic import ValidationError

@@ -12,8 +13,8 @@
from jupyter_server.utils import ensure_async

from .task_manager import TaskManager
from .models import ChatHistory, PromptRequest, ChatRequest
from langchain.schema import _message_to_dict, HumanMessage, AIMessage
from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient
from langchain.schema import HumanMessage

class APIHandler(BaseAPIHandler):
@property
@@ -60,26 +61,6 @@ async def post(self):
"insertion_mode": task.insertion_mode
}))

class ChatAPIHandler(APIHandler):
@tornado.web.authenticated
async def post(self):
try:
request = ChatRequest(**self.get_json_body())
except ValidationError as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e

if not self.openai_chat:
raise HTTPError(500, "No chat models available.")

result = await ensure_async(self.openai_chat.agenerate([request.prompt]))
output = result.generations[0][0].text
self.openai_chat.append_exchange(request.prompt, output)

self.finish(json.dumps({
"output": output,
}))

class TaskAPIHandler(APIHandler):
@tornado.web.authenticated
async def get(self, id=None):
@@ -106,25 +87,24 @@ def chat_provider(self):
if self._chat_provider is None:
self._chat_provider = self.settings["chat_provider"]
return self._chat_provider

@property
def messages(self):
self._messages = self.chat_provider.memory.chat_memory.messages or []
return self._messages

@property
def chat_history(self):
return self.settings["chat_history"]

@chat_history.setter
def _chat_history_setter(self, new_history):
self.settings["chat_history"] = new_history

@tornado.web.authenticated
async def get(self):
messages = []
for message in self.messages:
messages.append(message)
history = ChatHistory(messages=messages)

self.finish(history.json(models_as_dict=False))
history = ChatHistory(messages=self.chat_history)
self.finish(history.json())

@tornado.web.authenticated
async def delete(self):
self.chat_provider.memory.chat_memory.clear()
self.messages = []
self.chat_history = []
self.set_status(204)
self.finish()

@@ -154,18 +134,26 @@ def chat_message_queue(self):
return self._chat_message_queue

@property
def messages(self):
self._messages = self.chat_provider.memory.chat_memory.messages or []
return self._messages
def chat_handlers(self) -> Dict[str, 'ChatHandler']:
"""Dictionary mapping client IDs to their WebSocket handler
instances."""
return self.settings["chat_handlers"]

def add_chat_client(self, username):
self.settings["chat_clients"][username] = self
self.log.debug("Clients are : %s", self.settings["chat_clients"].keys())
@property
def chat_clients(self) -> Dict[str, ChatClient]:
"""Dictionary mapping client IDs to their ChatClient objects that store
metadata."""
return self.settings["chat_clients"]

def remove_chat_client(self, username):
self.settings["chat_clients"][username] = None
self.log.debug("Chat clients: %s", self.settings['chat_clients'].keys())
@property
def chat_client(self) -> ChatClient:
"""Returns ChatClient object associated with the current connection."""
return self.chat_clients[self.client_id]

@property
def chat_history(self) -> List[ChatMessage]:
return self.settings["chat_history"]

def initialize(self):
self.log.debug("Initializing websocket connection %s", self.request.path)

@@ -188,24 +176,45 @@ async def get(self, *args, **kwargs):
res = super().get(*args, **kwargs)
await res

def generate_client_id(self):
"""Generates a client ID to identify the current WS connection."""
# if collaborative mode is enabled, each client already has a UUID
# collaborative = self.config.get("LabApp", {}).get("collaborative", False)
# if collaborative:
# return self.current_user.username

# if collaborative mode is not enabled, each client is assigned a UUID
return uuid.uuid4().hex

def open(self):
self.log.debug("Client with user %s connected...", self.current_user.username)
self.add_chat_client(self.current_user.username)
"""Handles opening of a WebSocket connection. Client ID can be retrieved
from `self.client_id`."""

client_id = self.generate_client_id()
chat_client_kwargs = {k: v for k, v in asdict(self.current_user).items() if k != "username"}

def broadcast_message(self, message: any, exclude_current_user: Optional[bool] = False):
"""Broadcasts message to all connected clients,
optionally excluding the current user
self.chat_handlers[client_id] = self
self.chat_clients[client_id] = ChatClient(**chat_client_kwargs, id=client_id)
self.client_id = client_id
self.write_message(ConnectionMessage(client_id=client_id).dict())

self.log.info(f"Client connected. ID: {client_id}")
self.log.debug("Clients are : %s", self.chat_handlers.keys())

def broadcast_message(self, message: ChatMessage):
"""Broadcasts message to all connected clients, optionally excluding the
current user. Appends message to `self.chat_history`.
"""

self.log.debug("Broadcasting message: %s to all clients...", message)
client_names = self.settings["chat_clients"].keys()
if exclude_current_user:
client_names = client_names - [self.current_user.username]
client_ids = self.chat_handlers.keys()

for username in client_names:
client = self.settings["chat_clients"][username]
for client_id in client_ids:
client = self.chat_handlers[client_id]
if client:
client.write_message(message)
client.write_message(message.dict())

self.chat_history.append(message)

async def on_message(self, message):
self.log.debug("Message recieved: %s", message)
@@ -217,24 +226,39 @@ async def on_message(self, message):
self.log.error(e)
return

# message sent to the agent instance
message = HumanMessage(
content=chat_request.prompt,
additional_kwargs=dict(user=asdict(self.current_user))
)
data = json.dumps(_message_to_dict(message))
# message broadcast to chat clients
chat_message_id = str(uuid.uuid4())
chat_message = HumanChatMessage(
id=chat_message_id,
body=chat_request.prompt,
client=self.chat_client,
)

# broadcast the message to other clients
self.broadcast_message(message=data, exclude_current_user=True)
self.broadcast_message(message=chat_message)

# process the message
response = await ensure_async(self.chat_provider.apredict(input=message.content))

response = AIMessage(
content=response
agent_message = AgentChatMessage(
id=str(uuid.uuid4()),
body=response,
reply_to=chat_message_id
)

# broadcast to all clients
self.broadcast_message(message=json.dumps(_message_to_dict(response)))
self.broadcast_message(message=agent_message)


def on_close(self):
self.log.debug("Disconnecting client with user %s", self.current_user.username)
self.remove_chat_client(self.current_user.username)
self.log.debug("Disconnecting client with user %s", self.client_id)

self.chat_handlers.pop(self.client_id, None)
self.chat_clients.pop(self.client_id, None)

self.log.info(f"Client disconnected. ID: {self.client_id}")
self.log.debug("Chat clients: %s", self.chat_handlers.keys())
51 changes: 41 additions & 10 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -1,16 +1,52 @@
from pydantic import BaseModel, validator
from typing import Dict, List, Literal

from langchain.schema import BaseMessage, _message_to_dict
from pydantic import BaseModel
from typing import Dict, List, Union, Literal, Optional

class PromptRequest(BaseModel):
task_id: str
engine_id: str
prompt_variables: Dict[str, str]

# the type of message used to chat with the agent
class ChatRequest(BaseModel):
prompt: str

class ChatClient(BaseModel):
id: str
initials: str
name: str
display_name: str
color: Optional[str]
avatar_url: Optional[str]

class AgentChatMessage(BaseModel):
type: Literal["agent"] = "agent"
id: str
body: str
# message ID of the HumanChatMessage it is replying to
reply_to: str

class HumanChatMessage(BaseModel):
type: Literal["human"] = "human"
id: str
body: str
client: ChatClient

class ConnectionMessage(BaseModel):
type: Literal["connection"] = "connection"
client_id: str

# the type of messages being broadcast to clients
ChatMessage = Union[
AgentChatMessage,
HumanChatMessage,
]

Message = Union[
AgentChatMessage,
HumanChatMessage,
ConnectionMessage
]

class ListEnginesEntry(BaseModel):
id: str
name: str
@@ -30,9 +66,4 @@ class DescribeTaskResponse(BaseModel):

class ChatHistory(BaseModel):
"""History of chat messages"""
messages: List[BaseMessage]

class Config:
json_encoders = {
BaseMessage: lambda v: _message_to_dict(v)
}
messages: List[ChatMessage]
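
A short sketch of how these models compose when building and serializing a chat history; the IDs, names, and message bodies are made up for illustration, and the import path assumes the package is installed as jupyter_ai:

# Hypothetical usage of the models defined above.
from jupyter_ai.models import (
    AgentChatMessage,
    ChatClient,
    ChatHistory,
    HumanChatMessage,
)

client = ChatClient(id="3f2a", initials="JQ", name="jovyan", display_name="Jovyan Q.")
human = HumanChatMessage(id="msg-1", body="Hello, agent!", client=client)
agent = AgentChatMessage(id="msg-2", body="Hello, human!", reply_to="msg-1")

# Each entry keeps its "type" discriminator ("human" or "agent"), which is how
# clients tell the message kinds apart after JSON decoding.
history = ChatHistory(messages=[human, agent])
print(history.json())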