From b6370b97822b279e1cac396e73f91a4ef9a3b9cf Mon Sep 17 00:00:00 2001 From: Jeroen van Zundert Date: Sat, 1 Oct 2022 11:04:22 +0200 Subject: [PATCH 1/2] Refactor show_graph Separate graph generation from display, so we can display the right error on missing imports. Closes #5042. --- py-polars/polars/internals/lazyframe/frame.py | 83 +++++++++---------- py-polars/polars/show_versions.py | 1 + py-polars/pyproject.toml | 3 +- py-polars/src/conversion.rs | 18 +--- 4 files changed, 41 insertions(+), 64 deletions(-) diff --git a/py-polars/polars/internals/lazyframe/frame.py b/py-polars/polars/internals/lazyframe/frame.py index 0c99493f4a87..45d10f296901 100644 --- a/py-polars/polars/internals/lazyframe/frame.py +++ b/py-polars/polars/internals/lazyframe/frame.py @@ -1,9 +1,6 @@ from __future__ import annotations -import os -import shutil import subprocess -import tempfile import typing from io import BytesIO, IOBase, StringIO from pathlib import Path @@ -609,7 +606,7 @@ def show_graph( output_path Write the figure to disk. raw_output - Return dot syntax. This cannot be combined with `show` + Return dot syntax. This cannot be combined with `show` and/or `output_path`. figsize Passed to matplotlib if `show` == True. type_coercion @@ -626,9 +623,6 @@ def show_graph( Will try to cache branching subplans that occur on self-joins or unions. """ - if raw_output: - show = False - _ldf = self._ldf.optimization_toggle( type_coercion, predicate_pushdown, @@ -638,48 +632,45 @@ def show_graph( common_subplan_elimination, ) - if show and _in_notebook(): - try: - from IPython.display import SVG, display - - dot = _ldf.to_dot(optimized) - svg = subprocess.check_output( - ["dot", "-Nshape=box", "-Tsvg"], input=f"{dot}".encode() - ) - return display(SVG(svg)) - except Exception as exc: - raise ImportError( - "Graphviz dot binary should be on your PATH and matplotlib should" - " be installed to show graph." - ) from exc - try: - import matplotlib.image as mpimg - import matplotlib.pyplot as plt - except ImportError: - raise ImportError( - "Graphviz dot binary should be on your PATH and matplotlib should be" - " installed to show graph." - ) from None dot = _ldf.to_dot(optimized) + if raw_output: + # we do not show a graph, nor save a graph to disk return dot - with tempfile.TemporaryDirectory() as tmpdir_name: - dot_path = os.path.join(tmpdir_name, "dot") - with open(dot_path, "w", encoding="utf8") as f: - f.write(dot) - - subprocess.run(["dot", "-Nshape=box", "-Tpng", "-O", dot_path]) - out_path = os.path.join(tmpdir_name, "dot.png") - - if output_path is not None: - shutil.copy(out_path, output_path) - - if show: - plt.figure(figsize=figsize) - img = mpimg.imread(out_path) - plt.imshow(img) - plt.show() - return None + + output_type = "svg" if _in_notebook() else "png" + + try: + graph = subprocess.check_output( + ["dot", "-Nshape=box", "-T" + output_type], input=f"{dot}".encode() + ) + except ImportError: + raise ImportError("Graphviz dot binary should be on your PATH") from None + + if output_path: + with Path(output_path).open(mode="wb") as file: + file.write(graph) + + if not show: + return None + + if _in_notebook(): + from IPython.display import SVG, display + + return display(SVG(graph)) + else: + try: + import matplotlib.image as mpimg + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + "matplotlib should be installed to show graph." + ) from None + plt.figure(figsize=figsize) + img = mpimg.imread(BytesIO(graph)) + plt.imshow(img) + plt.show() + return None def inspect(self: LDF, fmt: str = "{}") -> LDF: """ diff --git a/py-polars/polars/show_versions.py b/py-polars/polars/show_versions.py index e32992365942..1f13a94d865c 100644 --- a/py-polars/polars/show_versions.py +++ b/py-polars/polars/show_versions.py @@ -55,6 +55,7 @@ def _get_dependency_info() -> dict[str, str]: "fsspec", "connectorx", "xlsx2csv", + "matplotlib", ] return {name: _get_dep_version(name) for name in opt_deps} diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index accc9b41a783..f375afb1f3da 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -20,8 +20,9 @@ fsspec = ["fsspec"] connectorx = ["connectorx"] xlsx2csv = ["xlsx2csv >= 0.8.0"] timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"] +matplotlib = ["matplotlib"] all = [ - "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,timezone]", + "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,timezone,matplotlib]", ] [tool.isort] diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index dfb63716f481..7c4fbc5a1f8f 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -482,8 +482,6 @@ impl<'s> FromPyObject<'s> for Wrap> { Ok(AnyValue::Utf8(v).into()) } else if ob.get_type().name()?.contains("datetime") { Python::with_gil(|py| { - // windows - #[cfg(target_arch = "windows")] { let kwargs = PyDict::new(py); kwargs.set_item("tzinfo", py.None())?; @@ -497,25 +495,11 @@ impl<'s> FromPyObject<'s> for Wrap> { .unwrap(); let loc_tz = localize.call1((dt, "UTC")); - loc_tz.call_method0("timestamp")?; + let ts = loc_tz.unwrap().call_method0("timestamp")?; // s to us let v = (ts.extract::()? * 1000_000.0) as i64; Ok(AnyValue::Datetime(v, TimeUnit::Microseconds, &None).into()) } - // unix - #[cfg(not(target_arch = "windows"))] - { - let datetime = PyModule::import(py, "datetime")?; - let timezone = datetime.getattr("timezone")?; - let kwargs = PyDict::new(py); - kwargs.set_item("tzinfo", timezone.getattr("utc")?)?; - let dt = ob.call_method("replace", (), Some(kwargs))?; - let ts = dt.call_method0("timestamp")?; - // s to us - let v = (ts.extract::()? * 1_000_000.0) as i64; - // we choose us as that is pythons default unit - Ok(AnyValue::Datetime(v, TimeUnit::Microseconds, &None).into()) - } }) } else if ob.is_none() { Ok(AnyValue::Null.into()) From b721accee1671fa56468e1d489265c743f7e925e Mon Sep 17 00:00:00 2001 From: Jeroen van Zundert Date: Sat, 1 Oct 2022 11:25:29 +0200 Subject: [PATCH 2/2] Revert conversion.rs --- py-polars/src/conversion.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index 7c4fbc5a1f8f..dfb63716f481 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -482,6 +482,8 @@ impl<'s> FromPyObject<'s> for Wrap> { Ok(AnyValue::Utf8(v).into()) } else if ob.get_type().name()?.contains("datetime") { Python::with_gil(|py| { + // windows + #[cfg(target_arch = "windows")] { let kwargs = PyDict::new(py); kwargs.set_item("tzinfo", py.None())?; @@ -495,11 +497,25 @@ impl<'s> FromPyObject<'s> for Wrap> { .unwrap(); let loc_tz = localize.call1((dt, "UTC")); - let ts = loc_tz.unwrap().call_method0("timestamp")?; + loc_tz.call_method0("timestamp")?; // s to us let v = (ts.extract::()? * 1000_000.0) as i64; Ok(AnyValue::Datetime(v, TimeUnit::Microseconds, &None).into()) } + // unix + #[cfg(not(target_arch = "windows"))] + { + let datetime = PyModule::import(py, "datetime")?; + let timezone = datetime.getattr("timezone")?; + let kwargs = PyDict::new(py); + kwargs.set_item("tzinfo", timezone.getattr("utc")?)?; + let dt = ob.call_method("replace", (), Some(kwargs))?; + let ts = dt.call_method0("timestamp")?; + // s to us + let v = (ts.extract::()? * 1_000_000.0) as i64; + // we choose us as that is pythons default unit + Ok(AnyValue::Datetime(v, TimeUnit::Microseconds, &None).into()) + } }) } else if ob.is_none() { Ok(AnyValue::Null.into())