Skip to content

Commit

Permalink
Merge pull request LineaLabs#583 from LineaLabs/LIN-292-use-get-code
Browse files Browse the repository at this point in the history
LIN-292 use get_code instead of .code
  • Loading branch information
lionsardesai authored Apr 14, 2022
2 parents 4393af7 + 3fbddf9 commit 6ae9131
Show file tree
Hide file tree
Showing 19 changed files with 8,614 additions and 8,478 deletions.
1,856 changes: 928 additions & 928 deletions docs/source/tutorials/00_api_basics.ipynb

Large diffs are not rendered by default.

1,236 changes: 618 additions & 618 deletions examples/.dev/Analysis Scope.ipynb

Large diffs are not rendered by default.

1,856 changes: 928 additions & 928 deletions examples/tutorials/00_api_basics.ipynb

Large diffs are not rendered by default.

3,574 changes: 1,787 additions & 1,787 deletions examples/use-cases/predict_house_price/01_preprocessing.ipynb

Large diffs are not rendered by default.

2,830 changes: 1,415 additions & 1,415 deletions examples/use-cases/predict_house_price/02_modeling.ipynb

Large diffs are not rendered by default.

2,572 changes: 1,286 additions & 1,286 deletions examples/use-cases/predict_house_price/04_preprocessing_v2.ipynb

Large diffs are not rendered by default.

2,798 changes: 1,399 additions & 1,399 deletions examples/use-cases/predict_house_price/05_modeling_v2.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lineapy/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def notebook(
date_created=artifact.date_created,
)
api_artifact.version = artifact.version or VERSION_PLACEHOLDER
logger.info(api_artifact.code)
logger.info(api_artifact.get_code())


@linea_cli.command()
Expand Down Expand Up @@ -163,7 +163,7 @@ def file(
date_created=artifact.date_created,
)
api_artifact.version = artifact.version or VERSION_PLACEHOLDER
logger.info(api_artifact.code)
logger.info(api_artifact.get_code())


def generate_save_code(
Expand Down
82 changes: 68 additions & 14 deletions lineapy/graph_reader/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional, cast
Expand Down Expand Up @@ -63,40 +64,56 @@ def value(self) -> object:
"""
Get and return the value of the artifact
"""
value = self.db.get_node_value_from_db(
self._node_id, self._execution_id
)
if not value:
raise ValueError("No value saved for this node")
if value.value is None:
value = self._get_value_path()
if value is None:
return None
else:
# TODO - set unicode etc here

track(GetValueEvent(has_value=True))
with open(value.value, "rb") as f:
with open(value, "rb") as f:
return FilePickler.load(f)

def _get_value_path(
self, other: Optional[ArtifactORM] = None
) -> Optional[str]:
"""
Get the path to the value of the artifact.
:param other: Additional argument to let you query another artifact's value path.
This is set to be optional and if its not set, we will use the current artifact
"""
if other is not None:
value = self.db.get_node_value_from_db(
other.node_id, other.execution_id
)
else:
value = self.db.get_node_value_from_db(
self._node_id, self._execution_id
)
if not value:
raise ValueError("No value saved for this node")
return value.value

@property
def _subgraph(self) -> Graph:
    """
    Slice of the artifact's graph, as computed by ``get_slice_graph`` from
    the full graph with this artifact's node id as the only listed node.
    """
    full_graph = self._graph
    return get_slice_graph(full_graph, [self._node_id])

@property
def code(self) -> str:
def get_code(self, use_lineapy_serialization=True) -> str:
"""
Return the slices code for the artifact
"""
# FIXME: this seems a little heavy to just get the slice?
track(
GetCodeEvent(use_lineapy_serialization=True, is_session_code=False)
)
return get_source_code_from_graph(self._subgraph)
return self._de_linealize_code(
get_source_code_from_graph(self._subgraph),
use_lineapy_serialization,
)

@property
def session_code(self) -> str:
def get_session_code(self, use_lineapy_serialization=True) -> str:
"""
Return the raw session code for the artifact. This will include any
comments and non-code lines.
Expand All @@ -106,7 +123,44 @@ def session_code(self) -> str:
track(
GetCodeEvent(use_lineapy_serialization=False, is_session_code=True)
)
return self.db.get_source_code_for_session(self._session_id)
return self._de_linealize_code(
self.db.get_source_code_for_session(self._session_id),
use_lineapy_serialization,
)

def _de_linealize_code(
self, code: str, use_lineapy_serialization: bool
) -> str:
"""
De-linealize the code by removing any lineapy api references
"""
if use_lineapy_serialization:
return code
else:
lineapy_pattern = re.compile(
r"(lineapy.(save\(([\w]+),\s*[\"\']([\w\-\s]+)[\"\']\)|get\([\"\']([\w\-\s]+)[\"\']\).value))"
)
# init swapped version

def replace_fun(match):
if match.group(2).startswith("save"):
# TODO - this can be another artifact. find it using the match.group(4)
# dep_artifact = self.db.get_artifact_by_name(match.group(4))
path_to_use = self._get_value_path()
return f'pickle.dump({match.group(3)},open("{path_to_use}","wb"))'

elif match.group(2).startswith("get"):
# this typically will be a different artifact.
dep_artifact = self.db.get_artifact_by_name(match.group(5))
path_to_use = self._get_value_path(dep_artifact)
return f'pickle.load(open("{path_to_use}","rb"))'

swapped, replaces = lineapy_pattern.subn(replace_fun, code)
if replaces > 0:
swapped = "import pickle\n" + swapped
logger.debug("replaces made: %s", replaces)

return swapped

@property
def _graph(self) -> Graph:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
code="""import lineapy
x = 1
res = lineapy.save(x, "x")
slice = res.code
slice = res.get_code()
value = res.value
""",
location=PosixPath("[source file path]"),
Expand Down Expand Up @@ -72,25 +72,34 @@
).id,
],
)
call_3 = CallNode(
call_4 = CallNode(
source_location=SourceLocation(
lineno=4,
col_offset=8,
end_lineno=4,
end_col_offset=16,
end_col_offset=22,
source_code=source_1.id,
),
function_id=LookupNode(
name="getattr",
).id,
positional_args=[
call_2.id,
LiteralNode(
value="code",
function_id=CallNode(
source_location=SourceLocation(
lineno=4,
col_offset=8,
end_lineno=4,
end_col_offset=20,
source_code=source_1.id,
),
function_id=LookupNode(
name="getattr",
).id,
],
positional_args=[
call_2.id,
LiteralNode(
value="get_code",
).id,
],
).id,
)
call_4 = CallNode(
call_5 = CallNode(
source_location=SourceLocation(
lineno=5,
col_offset=8,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/end_to_end/test_artifact_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_slice_preserved_after_tracer_dereferenced(execute):
artifact_f.db, artifact_f._session_id
)
# this is an additional redundant step to make sure our orig artifact is correct
assert artifact_f.code == code_body
assert artifact_f.get_code() == code_body
# and here we only use the tracer context to ensure we can retrieve the
# slice from db directly
assert second_context.slice("deferencedy") == code_body
35 changes: 34 additions & 1 deletion tests/end_to_end/test_artifacts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime, timedelta

import pytest

from lineapy.utils.constants import VERSION_DATE_STRING


Expand Down Expand Up @@ -139,10 +141,41 @@ def test_artifact_session_code(execute):
print(x)"""
tracer = execute(importl + code_body + artifact_f_save, snapshot=False)
artifact = tracer.values["use_y"]
assert artifact.session_code == importl + code_body + artifact_f_save
assert artifact.get_session_code() == importl + code_body + artifact_f_save
assert (
artifact.db.get_session_context(
artifact._session_id
).environment_type.name
== "SCRIPT"
)


@pytest.mark.xfail(
    reason="fails because the reexecution of graph created a new random file that saves the value of cleanedx"
)
def test_artifact_code_without_lineapy(execute):
    """
    ``get_code()`` keeps the lineapy calls of the slice, while
    ``get_code(False)`` is expected to swap them for pickle calls.
    """
    code = """import lineapy
x = 1
savepath = lineapy.save(x, "cleanedx")
cleanedx = lineapy.get("cleanedx").value
y = cleanedx + 1
y_art = lineapy.save(y, "y")
"""
    tracer = execute(code, snapshot=False)
    saved_path = tracer.values["savepath"]._get_value_path()
    y_artifact = tracer.values["y_art"]

    # Slice with the lineapy API calls left in place.
    expected_with_lineapy = """import lineapy
cleanedx = lineapy.get("cleanedx").value
y = cleanedx + 1
"""
    assert y_artifact.get_code() == expected_with_lineapy

    # Slice with ``lineapy.get`` replaced by a pickle load of the saved value.
    expected_without_lineapy = f"""import pickle
import lineapy
cleanedx = pickle.load(open("{saved_path}", "rb"))
y = cleanedx + 1
"""
    assert y_artifact.get_code(False) == expected_without_lineapy
2 changes: 1 addition & 1 deletion tests/end_to_end/test_linea_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ def test_save_returns_artifact(execute):
c = """import lineapy
x = 1
res = lineapy.save(x, "x")
slice = res.code
slice = res.get_code()
value = res.value
"""

Expand Down
Loading

0 comments on commit 6ae9131

Please sign in to comment.