Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic collaborative chat #58

Merged
merged 6 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,6 @@ playground/
# reserve path for a dev script
dev.sh

.vscode
.vscode

.jupyter_ystore.db
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import queue
from jupyter_server.extension.application import ExtensionApp
from langchain import ConversationChain
from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler, ChatAPIHandler
from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler
from importlib_metadata import entry_points
import inspect
from .engine import BaseModelEngine
Expand All @@ -20,7 +20,6 @@ class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
("api/ai/prompt", PromptAPIHandler),
(r"api/ai/chat/?", ChatAPIHandler),
(r"api/ai/tasks/?", TaskAPIHandler),
(r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler),
(r"api/ai/chats/?", ChatHandler),
Expand Down Expand Up @@ -113,5 +112,7 @@ def initialize_settings(self):

# Store chat clients in a dictionary
self.settings["chat_clients"] = {}
self.settings["chat_handlers"] = {}


# store chat messages in memory for now
self.settings["chat_history"] = []
155 changes: 92 additions & 63 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import asdict
import json
from typing import Optional

from typing import Dict, List
import tornado
import uuid

from tornado.web import HTTPError
from pydantic import ValidationError

Expand All @@ -12,8 +13,8 @@
from jupyter_server.utils import ensure_async

from .task_manager import TaskManager
from .models import ChatHistory, PromptRequest, ChatRequest
from langchain.schema import _message_to_dict, HumanMessage, AIMessage
from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient
from langchain.schema import HumanMessage

class APIHandler(BaseAPIHandler):
@property
Expand Down Expand Up @@ -60,26 +61,6 @@ async def post(self):
"insertion_mode": task.insertion_mode
}))

class ChatAPIHandler(APIHandler):
    """REST handler for one-shot chat requests: runs the prompt against
    the OpenAI chat model in the app settings and returns the output."""

    @tornado.web.authenticated
    async def post(self):
        # validate the JSON body against the ChatRequest model
        try:
            request = ChatRequest(**self.get_json_body())
        except ValidationError as e:
            self.log.exception(e)
            raise HTTPError(500, str(e)) from e

        if not self.openai_chat:
            raise HTTPError(500, "No chat models available.")

        # generate a completion and record the prompt/output exchange
        # on the model wrapper
        result = await ensure_async(self.openai_chat.agenerate([request.prompt]))
        output = result.generations[0][0].text
        self.openai_chat.append_exchange(request.prompt, output)

        self.finish(json.dumps({
            "output": output,
        }))

class TaskAPIHandler(APIHandler):
@tornado.web.authenticated
async def get(self, id=None):
Expand All @@ -106,25 +87,24 @@ def chat_provider(self):
if self._chat_provider is None:
self._chat_provider = self.settings["chat_provider"]
return self._chat_provider

@property
def messages(self):
    """Messages held in the chat provider's conversation memory
    (an empty list when the memory holds none)."""
    # NOTE(review): the self._messages cache is overwritten on every
    # access, so it is a side effect rather than a real cache
    self._messages = self.chat_provider.memory.chat_memory.messages or []
    return self._messages

@property
def chat_history(self):
    """The shared in-memory chat message list stored in the server settings."""
    return self.settings["chat_history"]

@chat_history.setter
def chat_history(self, new_history):
    # The setter function must also be named `chat_history`; the original
    # bound the decorated property to `_chat_history_setter`, leaving
    # `chat_history` without a setter, so `self.chat_history = []`
    # raised AttributeError.
    self.settings["chat_history"] = new_history

@tornado.web.authenticated
async def get(self):
messages = []
for message in self.messages:
messages.append(message)
history = ChatHistory(messages=messages)

self.finish(history.json(models_as_dict=False))
history = ChatHistory(messages=self.chat_history)
self.finish(history.json())

@tornado.web.authenticated
async def delete(self):
self.chat_provider.memory.chat_memory.clear()
self.messages = []
self.chat_history = []
self.set_status(204)
self.finish()

Expand Down Expand Up @@ -153,19 +133,32 @@ def chat_message_queue(self):
self._chat_message_queue = self.settings["chat_message_queue"]
return self._chat_message_queue

@property
def chat_handlers(self) -> Dict[str, 'ChatHandler']:
    """Dictionary mapping client IDs to their WebSocket handler
    instances. Shared between handler instances via the server
    settings dictionary."""
    return self.settings["chat_handlers"]

@property
def chat_clients(self) -> Dict[str, ChatClient]:
    """Dictionary mapping client IDs to their ChatClient objects that store
    metadata. Shared between handler instances via the server settings
    dictionary."""
    return self.settings["chat_clients"]

@property
def chat_client(self) -> ChatClient:
    """Returns ChatClient object associated with the current connection.
    Only valid after `open()` has assigned `self.client_id`."""
    return self.chat_clients[self.client_id]

@property
def chat_history(self) -> List[ChatMessage]:
    # shared in-memory message log kept in the server settings
    return self.settings["chat_history"]

@property
def messages(self):
    """Messages held in the chat provider's conversation memory
    (an empty list when the memory holds none)."""
    # NOTE(review): self._messages is rewritten on every access — it is
    # a side effect, not a cache
    self._messages = self.chat_provider.memory.chat_memory.messages or []
    return self._messages
dlqqq marked this conversation as resolved.
Show resolved Hide resolved

def add_chat_client(self, username):
    """Registers this handler in the shared settings under *username*."""
    self.settings["chat_clients"][username] = self
    self.log.debug("Clients are : %s", self.settings["chat_clients"].keys())

def remove_chat_client(self, username):
    """Unregisters the chat client for *username* from the shared settings."""
    # Pop the entry entirely (tolerating unknown usernames) instead of
    # assigning None: the original left dead keys in the registry forever,
    # forcing consumers to guard with truthiness checks.
    self.settings["chat_clients"].pop(username, None)
    self.log.debug("Chat clients: %s", self.settings['chat_clients'].keys())

def initialize(self):
self.log.debug("Initializing websocket connection %s", self.request.path)

Expand All @@ -188,24 +181,45 @@ async def get(self, *args, **kwargs):
res = super().get(*args, **kwargs)
await res

def generate_client_id(self):
    """Generates a client ID to identify the current WS connection."""
    # When collaborative mode is enabled each client already has a UUID
    # (self.current_user.username) that could be reused:
    #   collaborative = self.config.get("LabApp", {}).get("collaborative", False)
    #   if collaborative:
    #       return self.current_user.username
    # For now every connection is assigned a fresh UUID.
    return uuid.uuid4().hex

def open(self):
self.log.debug("Client with user %s connected...", self.current_user.username)
self.add_chat_client(self.current_user.username)
"""Handles opening of a WebSocket connection. Client ID can be retrieved
from `self.client_id`."""

client_id = self.generate_client_id()
chat_client_kwargs = {k: v for k, v in asdict(self.current_user).items() if k != "username"}

self.chat_handlers[client_id] = self
self.chat_clients[client_id] = ChatClient(**chat_client_kwargs, id=client_id)
self.client_id = client_id
self.write_message(ConnectionMessage(client_id=client_id).dict())

def broadcast_message(self, message: any, exclude_current_user: Optional[bool] = False):
"""Broadcasts message to all connected clients,
optionally excluding the current user
self.log.info(f"Client connected. ID: {client_id}")
self.log.debug("Clients are : %s", self.chat_handlers.keys())

def broadcast_message(self, message: ChatMessage):
"""Broadcasts message to all connected clients, optionally excluding the
current user. Appends message to `self.chat_history`.
"""

self.log.debug("Broadcasting message: %s to all clients...", message)
client_names = self.settings["chat_clients"].keys()
if exclude_current_user:
client_names = client_names - [self.current_user.username]
client_ids = self.chat_handlers.keys()

for username in client_names:
client = self.settings["chat_clients"][username]
for client_id in client_ids:
client = self.chat_handlers[client_id]
if client:
client.write_message(message)
client.write_message(message.dict())

self.chat_history.append(message)

async def on_message(self, message):
self.log.debug("Message recieved: %s", message)
Expand All @@ -217,24 +231,39 @@ async def on_message(self, message):
self.log.error(e)
return

# message sent to the agent instance
message = HumanMessage(
content=chat_request.prompt,
additional_kwargs=dict(user=asdict(self.current_user))
dlqqq marked this conversation as resolved.
Show resolved Hide resolved
)
data = json.dumps(_message_to_dict(message))
# message broadcast to chat clients
chat_message_id = str(uuid.uuid4())
chat_message = HumanChatMessage(
id=chat_message_id,
body=chat_request.prompt,
client=self.chat_client,
)

# broadcast the message to other clients
self.broadcast_message(message=data, exclude_current_user=True)
self.broadcast_message(message=chat_message)

# process the message
response = await ensure_async(self.chat_provider.apredict(input=message.content))

response = AIMessage(
content=response
agent_message = AgentChatMessage(
id=str(uuid.uuid4()),
body=response,
reply_to=chat_message_id
)

# broadcast to all clients
self.broadcast_message(message=json.dumps(_message_to_dict(response)))
self.broadcast_message(message=agent_message)


def on_close(self):
self.log.debug("Disconnecting client with user %s", self.current_user.username)
self.remove_chat_client(self.current_user.username)
self.log.debug("Disconnecting client with user %s", self.client_id)

self.chat_handlers.pop(self.client_id, None)
self.chat_clients.pop(self.client_id, None)

self.log.info(f"Client disconnected. ID: {self.client_id}")
self.log.debug("Chat clients: %s", self.chat_handlers.keys())
51 changes: 41 additions & 10 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,52 @@
from pydantic import BaseModel, validator
from typing import Dict, List, Literal

from langchain.schema import BaseMessage, _message_to_dict
from pydantic import BaseModel
from typing import Dict, List, Union, Literal, Optional

class PromptRequest(BaseModel):
    """Request body for the prompt API: identifies a task, an engine,
    and the prompt variables to apply."""
    task_id: str
    engine_id: str
    prompt_variables: Dict[str, str]

# the type of message used to chat with the agent
class ChatRequest(BaseModel):
    """A chat prompt submitted by a client."""
    prompt: str

class ChatClient(BaseModel):
    """Metadata about a connected chat client; one instance per
    WebSocket connection."""
    # server-assigned connection ID (a UUID hex string)
    id: str
    # remaining fields are copied from the Jupyter Server user identity
    # (see ChatHandler.open)
    initials: str
    name: str
    display_name: str
    color: Optional[str]
    avatar_url: Optional[str]

class AgentChatMessage(BaseModel):
    """A chat message authored by the agent."""
    type: Literal["agent"] = "agent"
    id: str
    body: str
    # message ID of the HumanChatMessage it is replying to
    reply_to: str

class HumanChatMessage(BaseModel):
    """A chat message authored by a human user."""
    type: Literal["human"] = "human"
    id: str
    body: str
    # metadata of the client that sent the message
    client: ChatClient

class ConnectionMessage(BaseModel):
    """First message sent to a client after its WebSocket opens,
    informing it of its assigned client ID."""
    type: Literal["connection"] = "connection"
    client_id: str

# the type of messages being broadcast to clients
ChatMessage = Union[
    AgentChatMessage,
    HumanChatMessage,
]

# any message a client may receive, including the initial
# connection handshake
Message = Union[
    AgentChatMessage,
    HumanChatMessage,
    ConnectionMessage
]

class ListEnginesEntry(BaseModel):
id: str
name: str
Expand All @@ -30,9 +66,4 @@ class DescribeTaskResponse(BaseModel):

class ChatHistory(BaseModel):
"""History of chat messages"""
messages: List[BaseMessage]

class Config:
json_encoders = {
BaseMessage: lambda v: _message_to_dict(v)
}
messages: List[ChatMessage]
Loading