Add on_chat_message_start (langchain-ai#4499)
### Add on_chat_message_start to callback manager and base tracer

Goal: trace chat messages directly so they can be reloaded as chat messages
(stored in an integration-agnostic way).

Add an `on_chat_model_start` method. Fall back to `on_llm_start()` for
handlers that don't implement it.

Does so in a backwards-compatible, non-breaking way (for now).
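
As a rough illustration of the fallback path (the `PromptPrinter` handler below is hypothetical, not part of this change), a handler that only implements `on_llm_start` keeps working:

```python
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler


class PromptPrinter(BaseCallbackHandler):
    """Hypothetical handler written before `on_chat_model_start` existed."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # With this change, chat-model starts also land here: the callback
        # manager catches the NotImplementedError raised by the default
        # on_chat_model_start and re-dispatches to on_llm_start, passing
        # each message batch rendered via get_buffer_string().
        for prompt in prompts:
            print(prompt)
```

New handlers can instead override `on_chat_model_start` to receive the raw `List[List[BaseMessage]]` batches.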
vowelparrot committed May 11, 2023
1 parent bbf76db commit 4ee4792
Showing 12 changed files with 311 additions and 140 deletions.
40 changes: 39 additions & 1 deletion langchain/callbacks/base.py
@@ -4,7 +4,12 @@
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
)


class LLMManagerMixin:
@@ -123,6 +128,20 @@ def on_llm_start(
) -> Any:
"""Run when LLM starts running."""

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)

def on_chain_start(
self,
serialized: Dict[str, Any],
@@ -184,6 +203,11 @@ def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False

@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return False


class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
@@ -199,6 +223,20 @@ async def on_llm_start(
) -> None:
"""Run when LLM starts running."""

async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)

async def on_llm_new_token(
self,
token: str,
99 changes: 94 additions & 5 deletions langchain/callbacks/manager.py
@@ -2,6 +2,7 @@

import asyncio
import functools
import logging
import os
import warnings
from contextlib import contextmanager
@@ -22,8 +23,15 @@
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
get_buffer_string,
)

logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
@@ -87,15 +95,31 @@ def _handle_event(
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for CallbackManager."""
message_strings: Optional[List[str]] = None
for handler in handlers:
try:
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
getattr(handler, event_name)(*args, **kwargs)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
if message_strings is None:
message_strings = [get_buffer_string(m) for m in args[1]]
_handle_event(
[handler],
"on_llm_start",
"ignore_llm",
args[0],
message_strings,
*args[2:],
**kwargs,
)
else:
logger.warning(f"Error in {event_name} callback: {e}")
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
logger.warning(f"Error in {event_name} callback: {e}")


async def _ahandle_event_for_handler(
Expand All @@ -114,9 +138,22 @@ async def _ahandle_event_for_handler(
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
message_strings = [get_buffer_string(m) for m in args[1]]
await _ahandle_event_for_handler(
handler,
"on_llm_start",
"ignore_llm",
args[0],
message_strings,
*args[2:],
**kwargs,
)
else:
logger.warning(f"Error in {event_name} callback: {e}")
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
logger.warning(f"Error in {event_name} callback: {e}")


async def _ahandle_event(
@@ -531,6 +568,33 @@ def on_llm_start(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
"""Run when a chat model starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)

# Re-use the LLM Run Manager since the outputs are treated
# the same for now
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)

def on_chain_start(
self,
serialized: Dict[str, Any],
@@ -629,6 +693,31 @@ async def on_llm_start(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)

async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForLLMRun:
"""Run when a chat model starts running."""
if run_id is None:
run_id = uuid4()

await _ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)

return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)

async def on_chain_start(
self,
serialized: Dict[str, Any],
39 changes: 38 additions & 1 deletion langchain/callbacks/tracers/langchain.py
@@ -3,6 +3,7 @@

import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4

@@ -19,6 +20,7 @@
TracerSessionV2,
TracerSessionV2Create,
)
from langchain.schema import BaseMessage, messages_to_dict
from langchain.utils import raise_for_status_with_text


@@ -193,6 +195,36 @@ def load_default_session(self) -> TracerSessionV2:
"""Load the default tracing session and set it as the Tracer's session."""
return self.load_session("default")

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a chat model run."""
if self.session is None:
self.session = self.load_default_session()

run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None

execution_order = self._get_execution_order(parent_run_id_)
llm_run = LLMRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
serialized=serialized,
prompts=[],
extra={**kwargs, "messages": messages},
start_time=datetime.utcnow(),
execution_order=execution_order,
child_execution_order=execution_order,
session_id=self.session.id,
)
self._start_trace(llm_run)

def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
"""Convert a run to a Run."""
session = self.session or self.load_default_session()
Expand All @@ -201,7 +233,12 @@ def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
if isinstance(run, LLMRun):
run_type = "llm"
inputs = {"prompts": run.prompts}
if run.extra is not None and "messages" in run.extra:
messages: List[List[BaseMessage]] = run.extra.pop("messages")
converted_messages = [messages_to_dict(batch) for batch in messages]
inputs = {"messages": converted_messages}
else:
inputs = {"prompts": run.prompts}
outputs = run.response.dict() if run.response else {}
child_runs = []
elif isinstance(run, ChainRun):
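For context, a small sketch of the round trip this storage format is meant to enable, using the `messages_to_dict` / `messages_from_dict` helpers from `langchain.schema` (the example messages are made up): the tracer above serializes each message batch before sending it to the backend, and the client change later in this commit rebuilds `BaseMessage` objects from the stored dicts.

```python
from langchain.schema import (
    HumanMessage,
    SystemMessage,
    messages_from_dict,
    messages_to_dict,
)

# One batch of chat messages, as a chat model would receive them.
batch = [SystemMessage(content="You are helpful."), HumanMessage(content="Hi!")]

# Integration-agnostic storage: plain dicts that can be shipped as JSON.
as_dicts = messages_to_dict(batch)

# Reloading on the client side restores BaseMessage objects.
restored = messages_from_dict(as_dicts)
assert [m.content for m in restored] == [m.content for m in batch]
```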
2 changes: 1 addition & 1 deletion langchain/callbacks/tracers/schemas.py
@@ -117,6 +117,7 @@ class RunBase(BaseModel):
session_id: UUID
reference_example_id: Optional[UUID]
run_type: RunTypeEnum
parent_run_id: Optional[UUID]


class RunCreate(RunBase):
@@ -130,7 +131,6 @@ class Run(RunBase):
"""Run schema when loading from the DB."""

name: str
parent_run_id: Optional[UUID]


ChainRun.update_forward_refs()
11 changes: 4 additions & 7 deletions langchain/chat_models/base.py
@@ -24,7 +24,6 @@
HumanMessage,
LLMResult,
PromptValue,
get_buffer_string,
)


@@ -69,9 +68,8 @@ def generate(
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
run_manager = callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages
)

new_arg_supported = inspect.signature(self._generate).parameters.get(
@@ -104,9 +102,8 @@ async def agenerate(
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
run_manager = await callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages
)

new_arg_supported = inspect.signature(self._agenerate).parameters.get(
26 changes: 15 additions & 11 deletions langchain/client/langchain.py
@@ -31,9 +31,8 @@
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
from langchain.client.utils import parse_chat_messages
from langchain.llms.base import BaseLLM
from langchain.schema import ChatResult, LLMResult
from langchain.schema import ChatResult, LLMResult, messages_from_dict
from langchain.utils import raise_for_status_with_text, xor_args

if TYPE_CHECKING:
@@ -96,7 +95,6 @@ def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
"Unable to get seeded tenant ID. Please manually provide."
) from e
results: List[dict] = response.json()
breakpoint()
if len(results) == 0:
raise ValueError("No seeded tenant found")
return results[0]["id"]
@@ -296,13 +294,15 @@ async def _arun_llm(
langchain_tracer: LangChainTracerV2,
) -> Union[LLMResult, ChatResult]:
if isinstance(llm, BaseLLM):
if "prompts" not in inputs:
raise ValueError(f"LLM Run requires 'prompts' input. Got {inputs}")
llm_prompts: List[str] = inputs["prompts"]
llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
elif isinstance(llm, BaseChatModel):
chat_prompts: List[str] = inputs["prompts"]
messages = [
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
]
if "messages" not in inputs:
raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}")
raw_messages: List[List[dict]] = inputs["messages"]
messages = [messages_from_dict(batch) for batch in raw_messages]
llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")
@@ -454,13 +454,17 @@ def run_llm(
) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example."""
if isinstance(llm, BaseLLM):
if "prompts" not in inputs:
raise ValueError(f"LLM Run must contain 'prompts' key. Got {inputs}")
llm_prompts: List[str] = inputs["prompts"]
llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
elif isinstance(llm, BaseChatModel):
chat_prompts: List[str] = inputs["prompts"]
messages = [
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
]
if "messages" not in inputs:
raise ValueError(
f"Chat Model Run must contain 'messages' key. Got {inputs}"
)
raw_messages: List[List[dict]] = inputs["messages"]
messages = [messages_from_dict(batch) for batch in raw_messages]
llm_output = llm.generate(messages, callbacks=[langchain_tracer])
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")