Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(python): Refactor show_graph #5059

Merged
merged 2 commits into from
Oct 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 37 additions & 46 deletions py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import annotations

import os
import shutil
import subprocess
import tempfile
import typing
from io import BytesIO, IOBase, StringIO
from pathlib import Path
Expand Down Expand Up @@ -609,7 +606,7 @@ def show_graph(
output_path
Write the figure to disk.
raw_output
Return dot syntax. This cannot be combined with `show`
Return dot syntax. This cannot be combined with `show` and/or `output_path`.
figsize
Passed to matplotlib if `show` == True.
type_coercion
Expand All @@ -626,9 +623,6 @@ def show_graph(
Will try to cache branching subplans that occur on self-joins or unions.

"""
if raw_output:
show = False

_ldf = self._ldf.optimization_toggle(
type_coercion,
predicate_pushdown,
Expand All @@ -638,48 +632,45 @@ def show_graph(
common_subplan_elimination,
)

if show and _in_notebook():
try:
from IPython.display import SVG, display

dot = _ldf.to_dot(optimized)
svg = subprocess.check_output(
["dot", "-Nshape=box", "-Tsvg"], input=f"{dot}".encode()
)
return display(SVG(svg))
except Exception as exc:
raise ImportError(
"Graphviz dot binary should be on your PATH and matplotlib should"
" be installed to show graph."
) from exc
try:
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
except ImportError:
raise ImportError(
"Graphviz dot binary should be on your PATH and matplotlib should be"
" installed to show graph."
) from None
dot = _ldf.to_dot(optimized)

if raw_output:
# we do not show a graph, nor save a graph to disk
return dot
with tempfile.TemporaryDirectory() as tmpdir_name:
dot_path = os.path.join(tmpdir_name, "dot")
with open(dot_path, "w", encoding="utf8") as f:
f.write(dot)

subprocess.run(["dot", "-Nshape=box", "-Tpng", "-O", dot_path])
out_path = os.path.join(tmpdir_name, "dot.png")

if output_path is not None:
shutil.copy(out_path, output_path)

if show:
plt.figure(figsize=figsize)
img = mpimg.imread(out_path)
plt.imshow(img)
plt.show()
return None

output_type = "svg" if _in_notebook() else "png"

try:
graph = subprocess.check_output(
["dot", "-Nshape=box", "-T" + output_type], input=f"{dot}".encode()
)
except ImportError:
raise ImportError("Graphviz dot binary should be on your PATH") from None

if output_path:
with Path(output_path).open(mode="wb") as file:
file.write(graph)

if not show:
return None

if _in_notebook():
from IPython.display import SVG, display

return display(SVG(graph))
else:
try:
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
except ImportError:
raise ImportError(
"matplotlib should be installed to show graph."
) from None
plt.figure(figsize=figsize)
img = mpimg.imread(BytesIO(graph))
plt.imshow(img)
plt.show()
return None

def inspect(self: LDF, fmt: str = "{}") -> LDF:
"""
Expand Down
1 change: 1 addition & 0 deletions py-polars/polars/show_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _get_dependency_info() -> dict[str, str]:
"fsspec",
"connectorx",
"xlsx2csv",
"matplotlib",
]
return {name: _get_dep_version(name) for name in opt_deps}

Expand Down
3 changes: 2 additions & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ fsspec = ["fsspec"]
connectorx = ["connectorx"]
xlsx2csv = ["xlsx2csv >= 0.8.0"]
timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"]
matplotlib = ["matplotlib"]
all = [
"polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,timezone]",
"polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,timezone,matplotlib]",
]

[tool.isort]
Expand Down