Skip to content

Commit

Permalink
Added test for on_all_streams_end call
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Apr 27, 2024
1 parent 57524ba commit 4ebf8a8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
6 changes: 5 additions & 1 deletion agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_completion_stream(self,
if not inspect.isclass(event_handler):
raise Exception("Event handler must not be an instance.")

return self.main_thread.get_completion_stream(message=message,
res = self.main_thread.get_completion_stream(message=message,
message_files=message_files,
event_handler=event_handler,
attachments=attachments,
Expand All @@ -181,6 +181,10 @@ def get_completion_stream(self,
tool_choice=tool_choice
)

event_handler.on_all_streams_end()

return res

def demo_gradio(self, height=450, dark_mode=True, **kwargs):
"""
Launches a Gradio-based demo interface for the agency chatbot.
Expand Down
3 changes: 0 additions & 3 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,6 @@ def get_completion(self,

continue

if event_handler:
event_handler.on_all_streams_end()

return full_message

def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def test_5_agent_communication_stream(self):

test_tool_used = False
test_agent2_used = False
num_on_all_streams_end_calls = 0

class EventHandler(AgencyEventHandler):
@override
Expand All @@ -273,6 +274,12 @@ def on_tool_call_done(self, tool_call: ToolCall) -> None:
nonlocal test_tool_used
test_tool_used = True

@override
@classmethod
def on_all_streams_end(cls):
nonlocal num_on_all_streams_end_calls
num_on_all_streams_end_calls += 1

message = self.__class__.agency.get_completion_stream(
"Please tell TestAgent1 to tell TestAgent 2 to use test tool.",
event_handler=EventHandler,
Expand All @@ -284,6 +291,7 @@ def on_tool_call_done(self, tool_call: ToolCall) -> None:

self.assertTrue(test_tool_used)
self.assertTrue(test_agent2_used)
self.assertTrue(num_on_all_streams_end_calls == 1)

self.assertTrue(self.__class__.TestTool.shared_state.get("test_tool_used"))

Expand Down

0 comments on commit 4ebf8a8

Please sign in to comment.