Merge pull request LineaLabs#552 from LineaLabs/integration-tests
Check integration tests status with new code
saulshanabrook committed Mar 30, 2022
2 parents b7e674f + 6619d85 commit 0f17244
Showing 33 changed files with 854 additions and 190 deletions.
3 changes: 3 additions & 0 deletions lineapy/cli/cli.py
@@ -77,6 +77,7 @@ def notebook(
lineapy notebook my_notebook.ipynb notebook_file_system lineapy.file_system
"""
logger.info("Creating in memory notebook")
# Create the notebook:
notebook = nbformat.read(file, nbformat.NO_CONVERT)
notebook["cells"].append(
@@ -88,9 +89,11 @@ def notebook(
# Run the notebook:
setup_ipython_dir()
exec_proc = ExecutePreprocessor(timeout=None)
logger.info("Executing notebook")
exec_proc.preprocess(notebook)

# Print the slice:
logger.info("Printing slice")
# TODO: duplicated with `get` but no context set, should rewrite eventually
# to not duplicate
db = RelationalLineaDB.from_environment()
6 changes: 5 additions & 1 deletion lineapy/execution/context.py
@@ -17,6 +17,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from types import ModuleType
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional

from lineapy.execution.globals_dict import GlobalsDict, GlobalsDictResult
@@ -129,7 +130,10 @@ def set_context(
_current_context = ExecutionContext(
_input_node_ids=input_node_ids,
_input_globals_mutable={
k: is_mutable(v) for k, v in global_name_to_value.items()
# Don't consider modules or classes to be mutable inputs, so that any code which
# uses a module is assumed not to mutate it.
k: is_mutable(v) and not isinstance(v, (ModuleType, type))
for k, v in global_name_to_value.items()
},
node=node,
executor=executor,
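A minimal standalone sketch of the input-globals filter above, with stand-in values and a simplified is_mutable (the real predicate lives in lineapy.execution.inspect_function):

import math
from types import ModuleType


def is_mutable(obj: object) -> bool:
    # Simplified stand-in: treat unhashable objects as mutable.
    try:
        hash(obj)
    except TypeError:
        return True
    return False


global_name_to_value = {"xs": [1, 2], "math": math, "Point": type("Point", (), {}), "n": 3}
input_globals_mutable = {
    # Modules and classes are never treated as mutable inputs.
    k: is_mutable(v) and not isinstance(v, (ModuleType, type))
    for k, v in global_name_to_value.items()
}
print(input_globals_mutable)  # {'xs': True, 'math': False, 'Point': False, 'n': False}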
2 changes: 2 additions & 0 deletions lineapy/execution/executor.py
@@ -308,6 +308,7 @@ def _execute_call(
except Exception as exc:
raise UserException(exc, RemoveFrames(1), *changes)
finally:
logger.debug("Tearing down context")
# Check what has been changed and accessed in the globals
# Do this in a finally, so its always torn down even after exceptions
globals_result = teardown_context()
@@ -331,6 +332,7 @@ def _execute_call(
NOTE: we have a near-term engineering goal to refactor how side effects
are handled.
"""
logger.debug("Resolving side effects")
side_effects = chain(
map(self._process_global_side_effect, globals_result.side_effects),
(
62 changes: 46 additions & 16 deletions lineapy/execution/inspect_function.py
@@ -6,6 +6,7 @@
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from io import IOBase
from types import ModuleType
from typing import Callable, Dict, Hashable, Iterable, List, Optional, Tuple

@@ -45,19 +46,28 @@ def is_mutable(obj: object) -> bool:
Returns true if the object is mutable.
"""

# Assume all hashable objects are immutable
# I (yifan) think this is incorrect, but keeping the dead code
# here in case we run into some issues again

# try:
# hash(obj)
# except Exception:
# return True
# return False
if isinstance(obj, (str, int, bool, float, tuple, frozenset)):
return False
else:
# We have to special-case any types which are hashable but mutable.
# Since there is no way to tell a priori whether a class is mutable, we could add a list of
# types like this to our annotations
mutable_hashable_types: Tuple[type, ...] = (
ModuleType,
type,
type(iter([])),
IOBase,
)
if "sklearn.base" in sys.modules:
mutable_hashable_types += (sys.modules["sklearn.base"].BaseEstimator,) # type: ignore

# Special case some mutable hashable types
if isinstance(obj, mutable_hashable_types):
return True

# Otherwise assume all hashable objects are immutable
try:
hash(obj)
except Exception:
return True
return False


def validate(item: Dict) -> Optional[ModuleAnnotation]:
@@ -284,13 +294,14 @@ def _parse(self) -> None:
"""
Parses all specs which are for modules we have imported
"""
for module in list(self.specs.keys()):
if module not in sys.modules:
for module_name in list(self.specs.keys()):
module = get_imported_module(module_name)
if not module:
continue
self.parsed.add_annotations(
sys.modules[module],
module,
# Pop the spec once we have processed it
self.specs.pop(module),
self.specs.pop(module_name),
)

def __post_init__(self):
@@ -316,3 +327,22 @@ def inspect(
)
if processed_side_effect:
yield processed_side_effect


def get_imported_module(name: str) -> Optional[ModuleType]:
"""
Return a module, if it has been imported.
Also handles the corner case where a submodule has not been imported itself but is accessible
as an attribute on its parent module. This is needed for `tensorflow.keras.utils`, for example,
which is not imported when importing `tensorflow` but is accessible as an attribute of `tensorflow`.
"""
if name in sys.modules:
return sys.modules[name]
*parent_names, submodule_name = name.split(".")
if not parent_names:
return None
parent_module = get_imported_module(".".join(parent_names))
if not parent_module:
return None
return getattr(parent_module, submodule_name, None)
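A quick usage sketch of the two behaviours added above, assuming a lineapy checkout with this change installed; standard-library objects stand in for the tensorflow/sklearn cases:

import io
import os

from lineapy.execution.inspect_function import get_imported_module, is_mutable

# Hashable-but-mutable special cases now report as mutable.
assert is_mutable(os) is True            # module
assert is_mutable(int) is True           # class
assert is_mutable(iter([])) is True      # iterator
assert is_mutable(io.BytesIO()) is True  # IOBase subclass

# Plain hashable values are still treated as immutable.
assert is_mutable((1, 2)) is False

# Submodules reachable only as attributes of their parent still resolve.
assert get_imported_module("os.path") is os.path
assert get_imported_module("module_never_imported") is None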
16 changes: 16 additions & 0 deletions lineapy/external.annotations.yaml
@@ -18,6 +18,22 @@
side_effects:
- mutated_value:
external_state: file_system
- module: tensorflow.keras.utils
annotations:
- criteria:
function_name: get_file
side_effects:
- dependency:
external_state: file_system
- mutated_value:
external_state: file_system
- module: torch
annotations:
- criteria:
function_name: manual_seed
side_effects:
- mutated_value:
self_ref: SELF_REF
- module: torch.jit._script
annotations:
- criteria:
4 changes: 3 additions & 1 deletion lineapy/instrumentation/tracer.py
@@ -109,15 +109,17 @@ def process_node(self, node: Node) -> None:
# Update the graph from the side effects of the node.
# If an artifact could not be created, quietly return without saving the node to the DB.
##
logger.debug("Executing node %s", node)
try:
side_effects = self.executor.execute_node(
node,
{k: v.id for k, v in self.variable_name_to_node.items()},
)
except ArtifactSaveException as exc_info:
logger.error("Artifact could not be saved.")
logger.debug(exc_info)
logger.info(exc_info)
return
logger.debug("Processing side effects")

# Iterate through each side effect and process it, depending on its type
for e in side_effects:
23 changes: 22 additions & 1 deletion lineapy/internal.annotations.yaml
@@ -155,7 +155,28 @@
side_effects:
- mutated_value:
positional_argument_index: 0

- criteria: # inplace ops
function_names:
- iadd
- iand
- iconcat
- ifloordiv
- ilshift
- imod
- imul
- imatmul
- ior
- ipow
- irshift
- isub
- itruediv
- ixor
side_effects:
- mutated_value:
positional_argument_index: 0
- views:
- positional_argument_index: 0
- result: RESULT
- module: io
annotations:
- criteria:
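The new in-place-operator annotations say that positional argument 0 is mutated and that it is a view of the result; the operator module shows why both hold (a sketch, not LineaPy code):

from operator import iadd

x = [1, 2]
result = iadd(x, [3])  # functional form of `x += [3]`
assert x == [1, 2, 3]  # positional argument 0 was mutated in place ...
assert result is x     # ... and the result aliases (is a view of) argument 0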
45 changes: 37 additions & 8 deletions lineapy/system_tracing/_object_side_effects_to_side_effects.py
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Mapping, Set
@@ -22,6 +23,8 @@
)
from lineapy.utils.lineabuiltins import LINEA_BUILTINS

logger = logging.getLogger(__name__)

# Mapping of the ID of each external state object to its pointer
EXTERNAL_STATE_IDS: Dict[int, ExecutorPointer] = {
id(b): b for b in LINEA_BUILTINS.values() if isinstance(b, ExternalState)
@@ -57,8 +60,7 @@ def object_side_effects_to_side_effects(
# First track all the views and mutations in terms of objects
# (no Nodes reference)
tracker = ObjectMutationTracker()

for object_side_effect in object_side_effects:
for i, object_side_effect in enumerate(object_side_effects):
tracker.process_side_effect(object_side_effect)

# Mapping of object ids for the objects we care about, input nodes &
@@ -68,23 +70,26 @@
# ```y = [1]
# b = y
# ```
logger.debug("Creating list of objects we care about")

object_id_to_pointers: Dict[int, List[ExecutorPointer]] = defaultdict(list)
for object_id, pointer in EXTERNAL_STATE_IDS.items():
object_id_to_pointers[object_id].append(pointer)
for linea_id, obj in input_nodes.items():
object_id_to_pointers[id(obj)].append(ID(LineaID(linea_id)))

# Return all implicit dependencies
logger.debug("Returning implicit dependencies")

for id_ in tracker.implicit_dependencies:
for pointer in object_id_to_pointers[id_]:
yield ImplicitDependencyNode(pointer)

# Return all mutated nodes
logger.debug("Returning mutated nodes")
for id_ in tracker.mutated:
for pointer in object_id_to_pointers[id_]:
yield MutatedNode(pointer)

# Return all views
logger.debug("Returning views")

# For the views, we care about whether they include the outputs as well,
# so lets add those to the objects we care about
@@ -140,9 +145,33 @@ def process_side_effect(
self, object_side_effect: ObjectSideEffect
) -> None:
if isinstance(object_side_effect, ViewOfObjects):
set_as_viewers_generic(
[id(o) for o in object_side_effect.objects], self.viewers
)
# Special-case calls involving exactly two objects (the most common case) to speed up processing
if len(object_side_effect.objects) == 2:
l_obj, r_obj = object_side_effect.objects
l_id = id(l_obj)
r_id = id(r_obj)
if l_id == r_id:
return
l_viewers = self.viewers[l_id]
r_viewers = self.viewers[r_id]

already_viewers = l_id in r_viewers
if already_viewers:
return

# Since they are not views of each other, their viewer sets are mutually exclusive, so to take
# the union we can simply concatenate them without worrying about duplicates
l_viewers_copy = list(l_viewers)
l_viewers.extend(r_viewers)
l_viewers.append(r_id)

r_viewers.extend(l_viewers_copy)
r_viewers.append(l_id)

else:
ids = [id(o) for o in object_side_effect.objects]
set_as_viewers_generic(ids, self.viewers)

elif isinstance(object_side_effect, MutatedObject):
id_ = id(object_side_effect.object)
self.add_mutated(id_)
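A standalone sketch of the bookkeeping behind the two-object fast path above (names here are hypothetical; the real logic lives in ObjectMutationTracker.process_side_effect):

from collections import defaultdict
from typing import Dict, List

viewers: Dict[int, List[int]] = defaultdict(list)


def link_views(l_id: int, r_id: int) -> None:
    if l_id == r_id or l_id in viewers[r_id]:
        return  # same object, or already recorded as views of each other
    l_viewers, r_viewers = viewers[l_id], viewers[r_id]
    l_copy = list(l_viewers)
    # The two viewer lists are disjoint here, so extending never introduces duplicates.
    l_viewers.extend(r_viewers)
    l_viewers.append(r_id)
    r_viewers.extend(l_copy)
    r_viewers.append(l_id)


link_views(1, 2)
print(viewers[1], viewers[2])  # [2] [1]
link_views(1, 2)               # already linked: no-op, so no duplicate entries
print(viewers[1], viewers[2])  # [2] [1]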
1 change: 1 addition & 0 deletions lineapy/system_tracing/exec_and_record_function_calls.py
@@ -14,6 +14,7 @@ def exec_and_record_function_calls(
"""
Execute the code while recording all the function calls which originate from the code object.
"""
logger.info("Execing code")
trace_func = TraceFunc(code)
try:
settrace(trace_func)
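For reference, a bare-bones version of the sys.settrace pattern this module builds on (a sketch, not LineaPy's TraceFunc):

import sys

calls = []


def trace(frame, event, arg):
    if event == "call":
        calls.append(frame.f_code.co_name)
    return trace


def helper(n):
    return n * n


def target():
    return helper(1) + helper(2)


sys.settrace(trace)
try:
    target()
finally:
    sys.settrace(None)

print(calls)  # ['target', 'helper', 'helper']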
7 changes: 7 additions & 0 deletions lineapy/system_tracing/function_calls_to_side_effects.py
@@ -1,3 +1,4 @@
import logging
from typing import Iterable, Mapping

from lineapy.data.types import LineaID
@@ -11,6 +12,8 @@
)
from lineapy.system_tracing.function_call import FunctionCall

logger = logging.getLogger(__name__)


def function_calls_to_side_effects(
function_inspector: FunctionInspector,
@@ -26,9 +29,13 @@ def function_calls_to_side_effects(
:param input_nodes: Mapping of node ID to value for all the nodes that were passed in to this execution.
:param output_globals: Mapping of global identifier to the value of all globals that were set during this execution.
"""
logger.info("Converting function calls to object side effects")

object_side_effects = function_calls_to_object_side_effects(
function_inspector, function_calls
)

logger.info("Converting object side effects to node side effects")
return object_side_effects_to_side_effects(
object_side_effects, input_nodes, output_globals
)
3 changes: 2 additions & 1 deletion lineapy/utils/logging_config.py
@@ -8,6 +8,7 @@
import logging
import os

from rich.console import Console
from rich.logging import RichHandler

# https://rich.readthedocs.io/en/stable/logging.html#logging-handler
@@ -33,5 +34,5 @@ def configure_logging(level=None, LOG_SQL=False):
level=level,
format=FORMAT,
datefmt="[%X]",
handlers=[RichHandler()],
handlers=[RichHandler(console=Console(stderr=True))],
)
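The change above sends Rich log output to stderr; in isolation the pattern looks like this (a sketch using only rich's public API):

import logging

from rich.console import Console
from rich.logging import RichHandler

logging.basicConfig(
    level="INFO",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(console=Console(stderr=True))],
)
logging.getLogger(__name__).info("rendered to stderr, so piped stdout stays clean")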
2 changes: 1 addition & 1 deletion lineapy/utils/utils.py
@@ -64,7 +64,7 @@ def get_value_type(val: Any) -> Optional[ValueType]:
if "PIL" in sys.modules:
import PIL

if sys.version_info >= (3, 8):
if hasattr(PIL, "PngImagePlugin"):
if isinstance(val, PIL.PngImagePlugin.PngImageFile):
return ValueType.chart
if isinstance(val, PIL.Image.Image):
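The version gate is replaced above by a capability check; as a standalone guard it reads (a sketch, not the full get_value_type):

import sys

if "PIL" in sys.modules:
    import PIL

    # Only reference PngImagePlugin if this PIL installation actually exposes it,
    # rather than guessing from the Python version.
    if hasattr(PIL, "PngImagePlugin"):
        print("PNG-specific chart detection enabled")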
4 changes: 2 additions & 2 deletions lineapy/visualizer/graphviz.py
@@ -71,7 +71,7 @@
"pink": "#fddaec",
"grey": "#f2f2f2",
}
ALPHA = 50
ALPHA = 60
# Alpha value from 0 to 255 used to lighten the secondary colors
# Make all the secondary colors lighter by applying an alpha
# Graphviz takes this as an alpha value in hex form
@@ -97,7 +97,7 @@ def color(original_color: str, highlighted: bool) -> str:

# Mapping of each node type to its color
COLORS: Dict[ColorableType, str] = defaultdict(
lambda: BREWER_PASTEL["grey"],
lambda: "#d4d4d4", # use a slightly darker grey by default
{
NodeType.CallNode: BREWER_PASTEL["pink"],
NodeType.LiteralNode: BREWER_PASTEL["green"],
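How an ALPHA value in the 0-255 range becomes the hex suffix Graphviz accepts (the helper name below is hypothetical, not the one used in graphviz.py):

def lighten(color: str, alpha: int = 60) -> str:
    # "#RRGGBB" + "AA" -> "#RRGGBBAA"; 60/255 is roughly 24% opacity.
    return color + format(alpha, "02x")


print(lighten("#fddaec"))  # '#fddaec3c'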
1 change: 1 addition & 0 deletions pytest.ini
@@ -8,6 +8,7 @@ norecursedirs =
.ipynb_checkpoints
docs
sources
envs
filterwarnings =
ignore:A private pytest class or function was used.:pytest.PytestDeprecationWarning
addopts =