Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Mar 30, 2022
1 parent 475c304 commit 8e79583
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions lineapy/utils/tree_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

from functools import wraps
from typing import Callable, Optional, TypeVar, cast
from typing import Callable, Dict, Iterable, List, Optional, TypeVar, cast

import rich
from rich.tree import Tree
Expand All @@ -22,6 +22,7 @@
__all__ = ["start_tree_log", "print_tree_log"]


# Mapping from class name to the color we shoudl use for its methods
# https://rich.readthedocs.io/en/stable/appendix/colors.html#appendix-colors
CLASS_TO_COLOR = {
"NodeTransformer": "blue",
Expand All @@ -31,15 +32,21 @@
"FunctionInspector": "red",
}

# List of classes we want to log
CLASSES = [NodeTransformer, Executor, Tracer, FunctionInspector]

# Current top tree we are logging from
TOP_TREE: Optional[Tree] = None
# The current tree for the function we are in.
CURRENT_TREE: Optional[Tree] = None

C = TypeVar("C", bound=Callable)


def start_tree_log(label: str) -> None:
"""
Starts logging, by overriding the classes, and also sets the top level label for the tree.
"""
global TOP_TREE, CURRENT_TREE
TOP_TREE = Tree(label=label)
CURRENT_TREE = TOP_TREE
Expand All @@ -48,13 +55,17 @@ def start_tree_log(label: str) -> None:

def print_tree_log() -> None:
"""
Print the tree log
Print the tree log with rich.
"""
global TOP_TREE
rich.print(TOP_TREE)


def tree_log(fn: C) -> C:
"""
Decorator to enable logging for a function. Should preserve its behavior, but logs whenever it is called.
"""

@wraps(fn)
def inner(*args, **kwargs):
global CURRENT_TREE
Expand All @@ -70,7 +81,15 @@ def inner(*args, **kwargs):
return cast(C, inner)


def render_call(fn, args, kwargs):
def render_call(
fn: Callable, args: Iterable[object], kwargs: Dict[str, object]
) -> str:
"""
Render the function, args, and kwargs as a string for printing with rich.
It uses some styles to color the classes and bold/underline the function names:
https://rich.readthedocs.io/en/stable/style.html
"""
qualname = fn.__qualname__
parts = qualname.split(".")
# If it's a classname.function, try coloring the classname
Expand All @@ -87,11 +106,15 @@ def render_call(fn, args, kwargs):


def override_classes():
"""
Override the __getattribute__ on the classes we want to track, so that whenever a method is retrieved from them,
it will wrap it in a logger first.
"""
for c in CLASSES:
c.__getattribute__ = tree_log_getattribute # type: ignore
c.__getattribute__ = _tree_log_getattribute # type: ignore


def tree_log_getattribute(self, item):
def _tree_log_getattribute(self, item):
value = super(type(self), self).__getattribute__(item) # type: ignore
if callable(value):
return tree_log(value)
Expand Down

0 comments on commit 8e79583

Please sign in to comment.