diff --git a/lineapy/cli/cli.py b/lineapy/cli/cli.py index c33b0abd2..01e6b779a 100644 --- a/lineapy/cli/cli.py +++ b/lineapy/cli/cli.py @@ -5,7 +5,7 @@ import sys import tempfile from io import TextIOWrapper -from typing import List +from typing import List, Optional import click import nbformat @@ -42,7 +42,17 @@ def linea_cli(): @click.argument("file", type=click.File()) @click.argument("artifact_name") @click.argument("artifact_value", type=str) -def notebook(file: TextIOWrapper, artifact_name: str, artifact_value: str): +@click.option( + "--visualize-slice", + type=click.Path(dir_okay=False, path_type=pathlib.Path), + help="Create a visualization for the sliced code, save it to this path", +) +def notebook( + file: TextIOWrapper, + artifact_name: str, + artifact_value: str, + visualize_slice: Optional[pathlib.Path], +): """ Executes the notebook FILE, saves the value ARTIFACT_VALUE with name ARTIFACT_NAME, and prints the sliced code. @@ -58,11 +68,18 @@ def notebook(file: TextIOWrapper, artifact_name: str, artifact_value: str): notebook = nbformat.read(file, nbformat.NO_CONVERT) notebook["cells"].append( nbformat.v4.new_code_cell( - "import lineapy\n" - # Save to a new variable first, so that if artifact value is composite, the slice of creating it - # won't include the `lineapy.save` line. - f"linea_artifact_value = {artifact_value}\n" - f"lineapy.save(linea_artifact_value, {repr(artifact_name)})" + ( + "import lineapy\n" + # Save to a new variable first, so that if artifact value is composite, the slice of creating it + # won't include the `lineapy.save` line. + f"linea_artifact_value = {artifact_value}\n" + f"linea_artifact = lineapy.save(linea_artifact_value, {repr(artifact_name)})\n" + ) + + ( + f"linea_artifact.visualize({repr(str(visualize_slice.resolve()))})" + if visualize_slice + else "" + ) ) ) diff --git a/lineapy/graph_reader/apis.py b/lineapy/graph_reader/apis.py index baef3627c..e0162d0b9 100644 --- a/lineapy/graph_reader/apis.py +++ b/lineapy/graph_reader/apis.py @@ -111,17 +111,19 @@ def to_airflow(self, filename: Optional[str] = None) -> Path: ) return path - def visualize(self) -> None: + def visualize(self, path: Optional[str]) -> None: """ Displays the graph for this artifact. + + If a path is provided, will save it to that file instead. """ from lineapy.visualizer import Visualizer - display( - Visualizer.for_public_node( - self._graph, self.node_id - ).ipython_display_object() - ) + visualizer = Visualizer.for_public_node(self._graph, self.node_id) + if path: + visualizer.render_pdf_file(path) + else: + display(visualizer.ipython_display_object()) class LineaCatalog: diff --git a/lineapy/visualizer/optimize_svg.py b/lineapy/visualizer/optimize_svg.py index dddbfe8a5..8cf0534c7 100644 --- a/lineapy/visualizer/optimize_svg.py +++ b/lineapy/visualizer/optimize_svg.py @@ -4,8 +4,7 @@ import subprocess import tempfile - -from path import Path +from pathlib import Path # https://github.com/scour-project/scour#usage OPTIONS = [ diff --git a/setup.py b/setup.py index 32a8cde02..7a40b41bb 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ def version(path): "pyyaml", "asttokens", "isort", + "graphviz", ], extras_require={ "dev": [ @@ -95,7 +96,6 @@ def version(path): "nbval", "coveralls", "seaborn", - "graphviz", "pre-commit", "SQLAlchemy[mypy]>=1.4.0", "sphinx",