Skip to content

Commit

Permalink
community[patch]: avoid executing toolkit.get_context() when not ne…
Browse files Browse the repository at this point in the history
…cessary (langchain-ai#19762)

If `prompt` is passed into `create_sql_agent()`, then
`toolkit.get_context()` shouldn't be executed against the database
unless relevant prompt variables (`table_info` or `table_names`) are
present .
  • Loading branch information
x-arturs authored Mar 29, 2024
1 parent ec7a59c commit 2319212
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions libs/community/langchain_community/agent_toolkits/sql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,18 @@ def create_sql_agent(
prompt = prompt.partial(top_k=str(top_k))
if "dialect" in prompt.input_variables:
prompt = prompt.partial(dialect=toolkit.dialect)
db_context = toolkit.get_context()
if "table_info" in prompt.input_variables:
prompt = prompt.partial(table_info=db_context["table_info"])
tools = [
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
]
if "table_names" in prompt.input_variables:
prompt = prompt.partial(table_names=db_context["table_names"])
tools = [
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
db_context = toolkit.get_context()
if "table_info" in prompt.input_variables:
prompt = prompt.partial(table_info=db_context["table_info"])
tools = [
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
]
if "table_names" in prompt.input_variables:
prompt = prompt.partial(table_names=db_context["table_names"])
tools = [
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]

if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
if prompt is None:
Expand Down

0 comments on commit 2319212

Please sign in to comment.