diff --git a/py-polars/polars/internals/lazyframe/frame.py b/py-polars/polars/internals/lazyframe/frame.py index 0c99493f4a87..45d10f296901 100644 --- a/py-polars/polars/internals/lazyframe/frame.py +++ b/py-polars/polars/internals/lazyframe/frame.py @@ -1,9 +1,6 @@ from __future__ import annotations -import os -import shutil import subprocess -import tempfile import typing from io import BytesIO, IOBase, StringIO from pathlib import Path @@ -609,7 +606,7 @@ def show_graph( output_path Write the figure to disk. raw_output - Return dot syntax. This cannot be combined with `show` + Return dot syntax. This cannot be combined with `show` and/or `output_path`. figsize Passed to matplotlib if `show` == True. type_coercion @@ -626,9 +623,6 @@ def show_graph( Will try to cache branching subplans that occur on self-joins or unions. """ - if raw_output: - show = False - _ldf = self._ldf.optimization_toggle( type_coercion, predicate_pushdown, @@ -638,48 +632,45 @@ def show_graph( common_subplan_elimination, ) - if show and _in_notebook(): - try: - from IPython.display import SVG, display - - dot = _ldf.to_dot(optimized) - svg = subprocess.check_output( - ["dot", "-Nshape=box", "-Tsvg"], input=f"{dot}".encode() - ) - return display(SVG(svg)) - except Exception as exc: - raise ImportError( - "Graphviz dot binary should be on your PATH and matplotlib should" - " be installed to show graph." - ) from exc - try: - import matplotlib.image as mpimg - import matplotlib.pyplot as plt - except ImportError: - raise ImportError( - "Graphviz dot binary should be on your PATH and matplotlib should be" - " installed to show graph." - ) from None dot = _ldf.to_dot(optimized) + if raw_output: + # we do not show a graph, nor save a graph to disk return dot - with tempfile.TemporaryDirectory() as tmpdir_name: - dot_path = os.path.join(tmpdir_name, "dot") - with open(dot_path, "w", encoding="utf8") as f: - f.write(dot) - - subprocess.run(["dot", "-Nshape=box", "-Tpng", "-O", dot_path]) - out_path = os.path.join(tmpdir_name, "dot.png") - - if output_path is not None: - shutil.copy(out_path, output_path) - - if show: - plt.figure(figsize=figsize) - img = mpimg.imread(out_path) - plt.imshow(img) - plt.show() - return None + + output_type = "svg" if _in_notebook() else "png" + + try: + graph = subprocess.check_output( + ["dot", "-Nshape=box", "-T" + output_type], input=f"{dot}".encode() + ) + except ImportError: + raise ImportError("Graphviz dot binary should be on your PATH") from None + + if output_path: + with Path(output_path).open(mode="wb") as file: + file.write(graph) + + if not show: + return None + + if _in_notebook(): + from IPython.display import SVG, display + + return display(SVG(graph)) + else: + try: + import matplotlib.image as mpimg + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + "matplotlib should be installed to show graph." + ) from None + plt.figure(figsize=figsize) + img = mpimg.imread(BytesIO(graph)) + plt.imshow(img) + plt.show() + return None def inspect(self: LDF, fmt: str = "{}") -> LDF: """ diff --git a/py-polars/polars/show_versions.py b/py-polars/polars/show_versions.py index e32992365942..1f13a94d865c 100644 --- a/py-polars/polars/show_versions.py +++ b/py-polars/polars/show_versions.py @@ -55,6 +55,7 @@ def _get_dependency_info() -> dict[str, str]: "fsspec", "connectorx", "xlsx2csv", + "matplotlib", ] return {name: _get_dep_version(name) for name in opt_deps} diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index accc9b41a783..f375afb1f3da 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -20,8 +20,9 @@ fsspec = ["fsspec"] connectorx = ["connectorx"] xlsx2csv = ["xlsx2csv >= 0.8.0"] timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"] +matplotlib = ["matplotlib"] all = [ - "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,timezone]", + "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,timezone,matplotlib]", ] [tool.isort]