Skip to content

Commit

Permalink
feat: chrome extension integration (#217)
Browse files Browse the repository at this point in the history
Add WebAPI functionality required for the Chrome extension

* feat: Added Websocket echo endpoint.

* feat: Add chat history handler endpoint

* feat: Add token authentication for api endpoints

* feat: Add token refresh endpoint.

* feat: Add chat history endpoint handler.

* chore: Cleanup debug logs.

* refactor: Handle token refresh.

* feat: Add chat thread endpoint handlers.

* feat: Add chat history post request validation.

Co-authored-by: Janaka Abeywardhana <contact@janaka.co.uk>

* chore: Add spaces api endpoint handler.

* feat(api middleware): add support for named path arguments.
  * Tornado doesn't support named path args, e.g. `api/items/{item_id}/`, which is the modern convention.

* chore: Handle rag history and threads.

* chore: Register Spaces and file_upload handlers.

* chore: Add top questions handler.

* chore: Add summary questions endpoint handler.

* chore: Fix RAG (use saved collection settings if none are provided).

- example pattern to follow in all handlers that operate directly on a domain entity.

* !refactor(API): various changes to routes design
- prefix all routes with v1
- adjust the domain entity route design to be plural, path args, and query string only for filters. trying to follow REST

* chore: Format token handler.

* chore: Enforce path arguments type annotation.

* chore: Set API key as a custom header.

* chore: Add token validation and refresh handlers.


---------

Co-authored-by: Janaka Abeywardhana <contact@janaka.co.uk>
  • Loading branch information
osala-eng and janaka committed Mar 8, 2024
1 parent 51a87fd commit 2cc40d9
Show file tree
Hide file tree
Showing 21 changed files with 960 additions and 226 deletions.
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ google-cloud-aiplatform = "^1.38.0"
litellm = "^1.26.6"
semantic-kernel = "^0.4.3.dev0"
imap-tools = "^1.5.0"
jwt = "^1.3.1"
llama-index = "0.9.8.post1"

[tool.poetry.group.dev.dependencies]
Expand Down
14 changes: 7 additions & 7 deletions source/docq/manage_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,20 @@ def llama_index_chat_prompt_template_from_persona(persona: Assistant) -> ChatPro
return ChatPromptTemplate(message_templates=[_system_prompt_message, _user_prompt_message])


def get_personas_fixed(assistant_type: Optional[AssistantType] = None) -> dict[str, Assistant]:
def get_personas_fixed(llm_settings_collection_key: str, assistant_type: Optional[AssistantType] = None) -> dict[str, Assistant]:
"""Get the personas."""
result = {}
if assistant_type == AssistantType.SIMPLE_CHAT:
result = {key: Assistant(key=key, **persona) for key, persona in SIMPLE_CHAT_PERSONAS.items()}
result = {key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in SIMPLE_CHAT_PERSONAS.items()}
elif assistant_type == AssistantType.AGENT:
result = {key: Assistant(key=key, **persona) for key, persona in AGENT_PERSONAS.items()}
result = {key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in AGENT_PERSONAS.items()}
elif assistant_type == AssistantType.ASK:
result = {key: Assistant(key=key, **persona) for key, persona in ASK_PERSONAS.items()}
result = {key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in ASK_PERSONAS.items()}
else:
result = {
**{key: Assistant(key=key, **persona) for key, persona in SIMPLE_CHAT_PERSONAS.items()},
**{key: Assistant(key=key, **persona) for key, persona in AGENT_PERSONAS.items()},
**{key: Assistant(key=key, **persona) for key, persona in ASK_PERSONAS.items()},
**{key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in SIMPLE_CHAT_PERSONAS.items()},
**{key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in AGENT_PERSONAS.items()},
**{key: Assistant(key=key, **persona, llm_settings_collection_key=llm_settings_collection_key) for key, persona in ASK_PERSONAS.items()},
}
return result

Expand Down
14 changes: 14 additions & 0 deletions source/docq/manage_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,17 @@ def update_shared_space_permissions(id_: int, accessors: List[SpaceAccessor]) ->
)
connection.commit()
return True


def get_space(space_id: int, org_id: int) -> Optional[SPACE]:
    """Look up one space by id within an organisation.

    Args:
        space_id: Primary key of the space to fetch.
        org_id: Organisation the space must belong to; a space with a matching
            id but a different org is treated as not found.

    Returns:
        The row passed through `_format_space`, or None when no row matches.
    """
    log.debug("get_space(): Getting space with id=%d", space_id)

    query = "SELECT id, org_id, name, summary, archived, datasource_type, datasource_configs, space_type, created_at, updated_at FROM spaces WHERE id = ? AND org_id = ?"
    db_file = get_sqlite_shared_system_file()

    with closing(sqlite3.connect(db_file, detect_types=sqlite3.PARSE_DECLTYPES)) as connection:
        with closing(connection.cursor()) as cursor:
            # Both filters are bound parameters, never interpolated into the SQL.
            record = cursor.execute(query, (space_id, org_id)).fetchone()
            if not record:
                return None
            return _format_space(record)
7 changes: 5 additions & 2 deletions source/docq/run_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _retrieve_messages(
return rows


def list_thread_history(feature: FeatureKey) -> list[tuple[int, str, int]]:
def list_thread_history(feature: FeatureKey, id_: Optional[int] = None) -> list[tuple[int, str, int]]:
"""List the history of threads."""
tablename = get_history_thread_table_name(feature.type_)
rows = None
Expand All @@ -119,7 +119,10 @@ def list_thread_history(feature: FeatureKey) -> list[tuple[int, str, int]]:
table=tablename,
)
)
rows = cursor.execute(f"SELECT id, topic, created_at FROM {tablename} ORDER BY created_at DESC").fetchall() # noqa: S608
if id_:
rows = cursor.execute(f"SELECT id, topic, created_at FROM {tablename} WHERE id = ?", (id_,)).fetchall() # noqa: S608
else:
rows = cursor.execute(f"SELECT id, topic, created_at FROM {tablename} ORDER BY created_at DESC").fetchall() # noqa: S608

return rows

Expand Down
10 changes: 10 additions & 0 deletions web/api/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Docq.AI RESTful API

## Introduction

This is a RESTful API that provides access to Docq.AI SaaS.
[Postman Collection](https://www.postman.com/spacecraft-physicist-48460084/workspace/docq-api/collection/22287507-cae373c0-bdf6-4efe-9594-f2d8fd10f924?action=share&creator=22287507)

## Authentication

The API uses JWT for authentication. You can obtain a token by sending a POST request to the `/api/{version}/token` endpoint with your username and password.
64 changes: 64 additions & 0 deletions web/api/base_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Base request handlers."""
from typing import Any, Optional, Self

import docq.manage_organisations as m_orgs
from opentelemetry import trace
from pydantic import ValidationError
from tornado.web import HTTPError, RequestHandler

from web.api.models import UserModel
from web.utils.handlers import _default_org_id as get_default_org_id

tracer = trace.get_tracer(__name__)


class BaseRequestHandler(RequestHandler):
    """Base request handler for the token-authenticated API endpoints.

    Disables the XSRF cookie check, resolves the caller from the
    ``Authorization: Bearer <token>`` request header and lazily resolves the
    caller's selected organisation.
    """

    # Cache for `selected_org_id`; None means "not resolved yet".
    __selected_org_id: Optional[int] = None

    def check_origin(self: Self, origin: Any) -> bool:
        """Override the origin check if it's causing problems.

        Always returns True, so every origin is accepted.
        """
        return True

    def check_xsrf_cookie(self: Self) -> bool:
        """Override the XSRF cookie check.

        This override never raises, which effectively disables the check.
        """
        # If enabled, POST, PUT, and DELETE are blocked unless the `_xsrf` cookie is set.
        # Safe with token based authN
        return False

    @property
    def selected_org_id(self: Self) -> int:
        """Get the selected org id.

        Resolved on first access from the current user's organisations and
        cached on the handler instance for subsequent reads.
        """
        if self.__selected_org_id is None:
            u = self.current_user
            member_orgs = m_orgs.list_organisations(user_id=u.uid)
            self.__selected_org_id = get_default_org_id(member_orgs, (u.uid, u.fullname, u.super_admin, u.username))
        return self.__selected_org_id

    @tracer.start_as_current_span("get_current_user")
    def get_current_user(self: Self) -> UserModel:
        """Retrieve user data from token.

        Returns:
            The user described by the ``data`` entry of the decoded JWT payload.

        Raises:
            HTTPError: 401 when the Authorization header is missing, is not of
                the form ``<scheme> <token>``, uses a scheme other than Bearer,
                or carries a payload that fails `UserModel` validation.
        """
        span = trace.get_current_span()

        auth_header = self.request.headers.get("Authorization")
        if not auth_header:
            error_msg = "Missing Authorization header"
            span.set_status(trace.Status(trace.StatusCode.ERROR))
            span.record_exception(ValueError(error_msg))
            raise HTTPError(401, reason=error_msg, log_message=error_msg)

        # Unpacking `split(" ")` directly raised an uncaught ValueError (a 500)
        # for headers such as "Bearer" or "Bearer  <token>". Split on any run of
        # whitespace and reject everything that is not exactly scheme + token.
        parts = auth_header.split()
        if len(parts) != 2:
            error_msg = "Malformed Authorization header"
            span.set_status(trace.Status(trace.StatusCode.ERROR))
            span.record_exception(ValueError(error_msg))
            raise HTTPError(401, reason=error_msg, log_message=error_msg)

        scheme, token = parts
        if scheme.lower() != "bearer":
            span.set_status(trace.Status(trace.StatusCode.ERROR))
            span.record_exception(ValueError("Authorization scheme must be Bearer"))
            raise HTTPError(401, reason="Authorization scheme must be Bearer")

        try:
            # Imported here rather than at module level — presumably to avoid a
            # circular import with web.api.utils; confirm before moving it.
            from web.api.utils.auth_utils import decode_jwt

            # NOTE(review): only pydantic validation failures are mapped to 401;
            # whatever decode_jwt itself raises propagates unchanged — confirm intended.
            payload = decode_jwt(token)
            user = UserModel.model_validate(payload.get("data"))
            return user
        except ValidationError as e:
            raise HTTPError(401, reason="Unauthorized: Validation error") from e
74 changes: 35 additions & 39 deletions web/api/chat_completion_handler.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,66 @@
"""Handle /api/chat/completion requests."""
from typing import Any, Optional, Self
from typing import Optional, Self

from docq.manage_personas import get_persona
import docq.run_queries as rq
from docq.manage_assistants import get_personas_fixed
from docq.model_selection.main import get_model_settings_collection
from docq.run_queries import run_chat
from pydantic import Field, ValidationError
from tornado.web import HTTPError, RequestHandler
from tornado.web import HTTPError

from web.api.utils import CamelModel, authenticated
from web.api.base_handlers import BaseRequestHandler
from web.api.models import MessageResponseModel
from web.api.utils.auth_utils import authenticated
from web.api.utils.docq_utils import get_feature_key, get_message_object
from web.api.utils.pydantic_utils import CamelModel
from web.utils.streamlit_application import st_app


class PostRequestModel(CamelModel):
"""Pydantic model for the request body."""

input_: str = Field(..., alias="input")
thread_id: int
history: Optional[str] = Field(None)
llm_settings_collection_name: Optional[str] = Field(None)
persona_key: Optional[str] = Field(None)
assistant_key: Optional[str] = Field(None)

class PostResponseModel(CamelModel):
"""Pydantic model for the response body."""
response: str
meta: Optional[dict[str,str]] = None

@st_app.api_route("/api/chat/completion")
class ChatCompletionHandler(RequestHandler):
@st_app.api_route("/api/v1/chat/completion")
class ChatCompletionHandler(BaseRequestHandler):
"""Handle /api/chat/completion requests."""

def check_origin(self: Self, origin: Any) -> bool:
"""Override the origin check if it's causing problems."""
return True

def check_xsrf_cookie(self: Self) -> bool:
"""Override the XSRF cookie check."""
# If `True`, POST, PUT, and DELETE are block unless the `_xsrf` cookie is set.
# Safe with token based authN
return False

def get(self: Self) -> None:
"""Handle GET request."""
self.write({"message": "hello world 2"})



@authenticated
def post(self: Self) -> None:
"""Handle POST request.
Example:
```shell
```sh
curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer expected_token" -d /
'{"input":"what's the sun?", "modelSettingsCollectionName"}' http://localhost:8501/api/chat/completion
'{"input":"what is the sun?", "modelSettingsCollectionName"}' http://localhost:8501/api/v1/chat/completion
```
"""
body = self.request.body

feature = get_feature_key(self.current_user.uid, "chat")
try:
request_model = PostRequestModel.model_validate_json(body)
history = request_model.history if request_model.history else ""
model_usage_settings = get_model_settings_collection(request_model.llm_settings_collection_name) if request_model.llm_settings_collection_name else get_model_settings_collection("azure_openai_latest")
persona = get_persona(request_model.persona_key if request_model.persona_key else "default")
result = run_chat(input_=request_model.input_, history=history, model_settings_collection=model_usage_settings, persona=persona)
response_model = PostResponseModel(response=result.response, meta={"model_settings": model_usage_settings.key})
request = PostRequestModel.model_validate_json(self.request.body)
llm_settings_collection_name = request.llm_settings_collection_name or "azure_openai_latest"
model_usage_settings = get_model_settings_collection(llm_settings_collection_name)
assistant_key = request.assistant_key if request.assistant_key else "default"
assistant = get_personas_fixed(model_usage_settings.key)[assistant_key]
if not assistant:
raise HTTPError(400, reason="Invalid persona key")
thread_id = request.thread_id

result = rq.query(
input_=request.input_,
feature=feature,
thread_id=thread_id,
model_settings_collection=model_usage_settings,
persona=assistant,
)
messages = list(map(get_message_object, result))
response_model = MessageResponseModel(response=messages, meta={"model_settings": model_usage_settings.key})

self.write(response_model.model_dump())

except ValidationError as e:
raise HTTPError(status_code=400, reason="Invalid request body", log_message=str(e)) from e

8 changes: 3 additions & 5 deletions web/api/hello_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@

from tornado.web import RequestHandler

from web.api.utils import CamelModel
from web.api.utils.pydantic_utils import CamelModel
from web.utils.streamlit_application import st_app


class ResponseModel(CamelModel):
    """Pydantic model for the response body of the hello endpoint."""

    # Greeting text; the GET handler in this module sets it to "Hello World!".
    response: str

@st_app.api_route("/api/hello")
@st_app.api_route("/api/v1/hello")
class ChatCompletionHandler(RequestHandler):
"""Handle /api/hello requests."""
"""Handle /api/v1/hello requests."""

def check_origin(self: Self, origin) -> bool:
"""Override the origin check if it's causing problems."""
Expand All @@ -28,5 +28,3 @@ def get(self: Self) -> None:
"""Handle GET request."""
response = ResponseModel(response="Hello World!")
self.write(response.model_dump())


3 changes: 3 additions & 0 deletions web/api/index_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@ class name: route replace capitalise route segments remove `/` and `_`. Example:
chat_completion_handler, # noqa: F401 DO NOT REMOVE
hello_handler, # noqa: F401 DO NOT REMOVE
rag_completion_handler, # noqa: F401 DO NOT REMOVE
spaces_handler, # noqa: F401 DO NOT REMOVE
threads_handler, # noqa: F401 DO NOT REMOVE
token_handler, # noqa: F401 DO NOT REMOVE
)
53 changes: 53 additions & 0 deletions web/api/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""API models."""

from typing import Literal, Optional

from pydantic import BaseModel, Field

SPACE_TYPE = Literal["personal", "shared", "public", "thread"]
FEATURE = Literal["rag", "chat"]

class UserModel(BaseModel):
    """Pydantic model for user data.

    Describes the authenticated API caller. `BaseRequestHandler.get_current_user`
    validates the `data` entry of the decoded JWT payload against this model.
    """

    uid: int  # user id; passed to `list_organisations(user_id=...)` by the base handler
    fullname: str
    super_admin: bool  # NOTE(review): presumably marks a system-wide administrator — confirm
    username: str

class MessageModel(BaseModel):
    """Pydantic model for message data.

    One message belonging to a thread.
    """

    # The attribute is `id_`; input data supplies it under the key `id` (alias).
    id_: int = Field(..., alias="id")
    message: str
    human: bool  # NOTE(review): presumably True for user-authored messages — confirm
    timestamp: str  # kept as a string; the exact format is not visible here
    thread_id: int

class MessageResponseModel(BaseModel):
    """Pydantic model for a response body carrying messages.

    The chat completion handler fills `meta` with the key of the model
    settings collection that was used (`{"model_settings": <key>}`).
    """

    response: list[MessageModel]
    # Optional string-to-string metadata about the response; defaults to None.
    meta: Optional[dict[str,str]] = None

class ChatHistoryModel(BaseModel):
    """Pydantic model for chat history: a list of messages under `response`."""

    response : list[MessageModel]

class ThreadModel(BaseModel):
    """Pydantic model for a single thread (id, topic and creation time).

    This is the item type of `ThreadResponseModel.response`, not a response
    body in itself.
    """

    # The attribute is `id_`; input data supplies it under the key `id` (alias).
    id_: int = Field(..., alias="id")
    topic: str
    created_at: str  # kept as a string; the exact format is not visible here

class SpaceModel(BaseModel):
    """Pydantic model for a space (id, type and creation time)."""

    # The attribute is `id_`; input data supplies it under the key `id` (alias).
    id_: int = Field(..., alias="id")
    # Restricted to the SPACE_TYPE literals: "personal", "shared", "public", "thread".
    space_type: SPACE_TYPE
    created_at: str  # kept as a string; the exact format is not visible here

class ThreadResponseModel(BaseModel):
    """Pydantic model for a response body carrying a list of threads."""

    response: list[ThreadModel]

class ThreadPostRequestModel(BaseModel):
    """Pydantic model for the request body of a thread POST request."""

    # NOTE(review): presumably the topic of the thread being created — confirm against the handler.
    topic: str
Loading

0 comments on commit 2cc40d9

Please sign in to comment.