Callback Handler for MLflow #4150

Merged
merged 8 commits on May 11, 2023
Changes from 1 commit
fmt
dev2049 committed May 11, 2023
commit 2b7deb5beb9b9dcc3503fd4dac2aeef1b06e751c
39 changes: 19 additions & 20 deletions langchain/callbacks/mlflow_callback.py
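In substance, this fmt commit replaces built-in `dict[...]` generics in annotations with `typing.Dict[...]` (presumably so the module keeps importing on Python 3.8, where subscripting the built-in dict at runtime raises a TypeError), drops unused exception variables such as `as ae`, and removes a stray blank line inside the ImportError call. Below is a minimal sketch of the annotation style the diff converges on; the summarize helper is hypothetical and only illustrates the Dict[...] form, it is not part of mlflow_callback.py.

from typing import Any, Dict

def summarize(resp: Dict[str, Any]) -> Dict[str, int]:
    # typing.Dict keeps the annotation usable at runtime on Python 3.8,
    # where the built-in dict[str, Any] form is not subscriptable.
    return {key: len(str(value)) for key, value in resp.items()}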
@@ -26,7 +26,6 @@ def import_mlflow() -> Any:
raise ImportError(
"To use the mlflow callback manager you need to have the `mlflow` python "
"package installed. Please install it with `pip install mlflow>=2.3.0`"

)
return mlflow

@@ -45,7 +44,7 @@ def analyze_text(
(dict): A dictionary containing the complexity metrics and visualization
files serialized to HTML string.
"""
resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
textstat = import_textstat()
spacy = import_spacy()
text_complexity_metrics = {
@@ -146,7 +145,7 @@ def __init__(self, **kwargs: Any):

self.start_run(kwargs["run_name"], kwargs["run_tags"])

def start_run(self, name: str, tags: dict[str, str]) -> None:
def start_run(self, name: str, tags: Dict[str, str]) -> None:
"""To start a new run, auto generates the random suffix for name"""
if name.endswith("-%"):
rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
@@ -170,15 +169,15 @@ def metric(self, key: str, value: float) -> None:
self.mlflow.log_metric(key, value)

def metrics(
self, data: Union[dict[str, float], dict[str, int]], step: Optional[int] = 0
self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
) -> None:
"""To log all metrics in the input dict."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_metrics(data)

def jsonf(self, data: dict[str, Any], filename: str) -> None:
def jsonf(self, data: Dict[str, Any], filename: str) -> None:
"""To log the input data as json file artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
@@ -279,7 +278,7 @@ def __init__(
"agent_ends": 0,
}

self.records: dict[str, Any] = {
self.records: Dict[str, Any] = {
"on_llm_start_records": [],
"on_llm_token_records": [],
"on_llm_end_records": [],
@@ -309,7 +308,7 @@ def on_llm_start(

llm_starts = self.metrics["llm_starts"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
@@ -330,7 +329,7 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:

llm_streams = self.metrics["llm_streams"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.metrics)

@@ -348,7 +347,7 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:

llm_ends = self.metrics["llm_ends"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.metrics)
@@ -365,7 +364,7 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
nlp=self.nlp,
)
)
complexity_metrics: dict[str, float] = generation_resp.pop("text_complexity_metrics") # type: ignore
complexity_metrics: Dict[str, float] = generation_resp.pop("text_complexity_metrics") # type: ignore # noqa: E501
self.mlflg.metrics(
complexity_metrics,
step=self.metrics["step"],
@@ -395,7 +394,7 @@ def on_chain_start(

chain_starts = self.metrics["chain_starts"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
@@ -417,7 +416,7 @@ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:

chain_ends = self.metrics["chain_ends"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics)
@@ -445,7 +444,7 @@ def on_tool_start(

tool_starts = self.metrics["tool_starts"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
@@ -464,7 +463,7 @@ def on_tool_end(self, output: str, **kwargs: Any) -> None:

tool_ends = self.metrics["tool_ends"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.metrics)

@@ -490,7 +489,7 @@ def on_text(self, text: str, **kwargs: Any) -> None:

text_ctr = self.metrics["text_ctr"]

resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.metrics)

@@ -507,7 +506,7 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
self.metrics["ends"] += 1

agent_ends = self.metrics["agent_ends"]
resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
@@ -530,7 +529,7 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
self.metrics["starts"] += 1

tool_starts = self.metrics["tool_starts"]
resp: dict[str, Any] = {}
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
@@ -629,15 +628,15 @@ def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
try:
langchain_asset.save_agent(langchain_asset_path)
self.mlflg.artifact(langchain_asset_path)
except AttributeError as ae:
except AttributeError:
print("Could not save model.")
traceback.print_exc()
pass
except NotImplementedError as e:
except NotImplementedError:
print("Could not save model.")
traceback.print_exc()
pass
except NotImplementedError as e:
except NotImplementedError:
print("Could not save model.")
traceback.print_exc()
pass
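For context beyond the diff itself, a minimal usage sketch of the handler defined in mlflow_callback.py follows. The export name MlflowCallbackHandler, the OpenAI LLM, and the callbacks argument are assumptions rather than something this commit shows; only the flush_tracker(langchain_asset, finish) signature appears in the diff above.

# Hedged usage sketch; assumed names are flagged in the comments.
from langchain.callbacks import MlflowCallbackHandler  # assumed export name
from langchain.llms import OpenAI  # assumed LLM; any LLM accepting callbacks should do

# Constructing the handler starts an MLflow run; callback events (LLM, chain,
# tool, and agent starts/ends) are then logged as metrics and JSON artifacts.
mlflow_callback = MlflowCallbackHandler()

llm = OpenAI(temperature=0, callbacks=[mlflow_callback])
llm("Tell me a joke about experiment tracking.")

# flush_tracker is defined in this file; passing the asset and finish=True
# is intended to log it and wrap up the MLflow run.
mlflow_callback.flush_tracker(llm, finish=True)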