Skip to content

Commit

Permalink
fix: medrecord attribute tables overview (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
LauraBoenchenLB authored Sep 25, 2024
1 parent f1e44ad commit ae32479
Show file tree
Hide file tree
Showing 5 changed files with 386 additions and 344 deletions.
265 changes: 164 additions & 101 deletions medmodels/medrecord/_overview.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,133 @@
import copy
from datetime import datetime
from typing import Dict, List, Optional, Union

import polars as pl

from medmodels.medrecord.schema import AttributesSchema, AttributeType
from medmodels.medrecord.types import Attributes, EdgeIndex, NodeIndex
from medmodels.medrecord.types import (
AttributeInfo,
Attributes,
EdgeIndex,
Group,
MedRecordAttribute,
NodeIndex,
NumericAttributeInfo,
StringAttributeInfo,
TemporalAttributeInfo,
)


def extract_attribute_summary(
    attribute_dictionary: Union[
        Dict[EdgeIndex, Attributes], Dict[NodeIndex, Attributes]
    ],
    schema: Optional[AttributesSchema] = None,
    decimal: int = 2,
) -> Dict[
    MedRecordAttribute,
    Union[TemporalAttributeInfo, NumericAttributeInfo, StringAttributeInfo],
]:
    """Extracts a summary from a node or edge attribute dictionary.

    Every attribute found in the dictionary is summarised on its own. When the
    schema defines an attribute type for the attribute, that type decides which
    summary is produced; otherwise the polars dtype of the collected values is
    used.

    Example (illustrative shape of the returned dictionary):
        {
            "diagnosis_time": {"min": <earliest value>, "max": <latest value>},
            "duration_days": {"min": 0.0, "max": 3416.0, "mean": 405.02},
            "gender": {"values": "Values: F, M"},
            "unused_attribute": {"values": "-"},
        }

    Args:
        attribute_dictionary (Union[Dict[EdgeIndex, Attributes], Dict[NodeIndex, Attributes]]):
            Edges or Nodes and their attributes and values.
        schema (Optional[AttributesSchema], optional): Attribute Schema for the group
            nodes or edges. Defaults to None.
        decimal (int): Decimal points to round the numeric values to. Defaults to 2.
            Not read by this function: the values are returned unrounded.

    Returns:
        Dict[MedRecordAttribute, Union[TemporalAttributeInfo, NumericAttributeInfo,
            StringAttributeInfo]]: Summary of node or edge attributes, keyed by
            attribute name in sorted order.
    """
    data = pl.DataFrame(
        data=[{"id": k, **v} for k, v in attribute_dictionary.items()]
    )

    data_dict = {}

    for attribute in sorted(col for col in data.columns if col != "id"):
        attribute_values = data[attribute].drop_nulls()

        # attribute exists as a column but holds no value at all
        if len(attribute_values) == 0:
            data_dict[attribute] = {"values": "-"}
            continue

        # attribute type as defined in the schema, if there is one
        attribute_type = (
            schema[attribute][1] if schema and attribute in schema else None
        )

        if attribute_type:
            if attribute_type == AttributeType.Continuous:
                attribute_info = _extract_numeric_attribute_info(attribute_values)
            elif attribute_type == AttributeType.Temporal:
                attribute_info = _extract_temporal_attribute_info(attribute_values)
            else:
                # every other schema type is summarised as categories
                attribute_info = _extract_string_attribute_info(
                    attribute_series=attribute_values,
                    long_string_suffix="categories",
                    short_string_prefix="Categories",
                )
        # Without Schema: fall back to the dtype polars inferred for the values
        elif attribute_values.dtype.is_numeric():
            attribute_info = _extract_numeric_attribute_info(attribute_values)
        elif attribute_values.dtype.is_temporal():
            attribute_info = _extract_temporal_attribute_info(attribute_values)
        else:
            attribute_info = _extract_string_attribute_info(
                attribute_series=attribute_values
            )

        data_dict[attribute] = attribute_info

    return data_dict


def _extract_numeric_attribute_info(
    attribute_series: pl.Series,
) -> NumericAttributeInfo:
    """Extracts info about attributes with numeric format.

    Args:
        attribute_series (pl.Series): Series containing attribute values.

    Returns:
        NumericAttributeInfo: Dictionary containing the attribute metrics
            "min", "max" and "mean".
    """
    # local names deliberately avoid shadowing the builtins `min` and `max`
    minimum = attribute_series.min()
    maximum = attribute_series.max()
    mean = attribute_series.mean()

    # assertions to ensure correct typing
    # expected to hold as long as the series is numeric and contains no None
    # values
    assert isinstance(minimum, (int, float))
    assert isinstance(maximum, (int, float))
    assert isinstance(mean, (int, float))

    return {
        "min": minimum,
        "max": maximum,
        "mean": mean,
    }


def _extract_temporal_attribute_info(
    attribute_series: pl.Series,
) -> TemporalAttributeInfo:
    """Extracts info about attributes with temporal format.

    Series that do not already have a temporal dtype are converted first:
    numeric series are cast to ``pl.Datetime`` and every other series is parsed
    with ``str.to_datetime``.

    Args:
        attribute_series (pl.Series): Series containing attribute values.

    Returns:
        TemporalAttributeInfo: Dictionary containing the attribute metrics
            "min" and "max".
    """
    if not attribute_series.dtype.is_temporal():
        if attribute_series.dtype.is_numeric():
            # NOTE(review): the numbers are interpreted in the time unit of
            # pl.Datetime - confirm this matches how timestamps are stored
            attribute_series = attribute_series.cast(pl.Datetime)
        else:
            attribute_series = attribute_series.str.to_datetime()

    return {
        "min": min(attribute_series),
        "max": max(attribute_series),
    }


def _extract_string_attribute_info(
Expand All @@ -112,7 +136,7 @@ def _extract_string_attribute_info(
long_string_suffix: str = "unique values",
max_number_values: int = 5,
max_line_length: int = 100,
) -> str:
) -> StringAttributeInfo:
"""Extracts info about attributes with string format.
Args:
Expand All @@ -128,54 +152,93 @@ def _extract_string_attribute_info(
Defaults to 100.
Returns:
str: Attribute info string.
StringAttributeInfo: Dictionary containing attribute metrics.
"""
values = attribute_series.unique().sort()

values_string = f"{short_string_prefix}: {', '.join(list(values))}"

if (len(values) > max_number_values) | (len(values_string) > max_line_length):
return f"{len(values)} {long_string_suffix}"
else:
return values_string
values_string = f"{len(values)} {long_string_suffix}"

return {"values": values_string}


def prettify_table(
    data: Dict[Group, AttributeInfo], header: List[str], decimal: int
) -> List[str]:
    """Takes a dictionary of group infos and turns it into lines of a pretty table.

    Every group contributes one row per attribute metric. The group name and its
    count are shown on the first row of the group only, and the attribute name
    on the first row of the attribute only. Groups without attribute info get a
    single row with "-" placeholders.

    Args:
        data (Dict[Group, AttributeInfo]): Table info stored in a dictionary.
            Each entry is read through its "count" and "attribute" keys, where
            "attribute" maps attribute names to their metrics.
        header (List[str]): Header line consisting of the four column names for
            the table.
        decimal (int): Decimal point to round the float values to.

    Returns:
        List[str]: List of lines for printing the table.
    """
    lengths = [len(title) for title in header]

    rows = []

    # order in which the metrics of one attribute are listed
    info_order = ["min", "max", "mean", "values"]

    for group, group_info in data.items():
        group_name = str(group)
        count = str(group_info["count"])

        # determine longest group name and count
        lengths[0] = max(len(group_name), lengths[0])
        lengths[1] = max(len(count), lengths[1])

        attributes = group_info["attribute"]

        # in case of no attribute info, just keep Group name and count
        if not attributes:
            rows.append([group_name, count, "-", "-"])
            continue

        # display group name and count only once per group
        first_group_line = True

        for attribute, info in attributes.items():
            attribute_name = str(attribute)
            lengths[2] = max(len(attribute_name), lengths[2])

            # display attribute name only once
            first_line = True

            for key in sorted(info, key=info_order.index):
                value = info[key]

                # displaying info based on the type
                if "values" in info:
                    # string summaries already carry their own prefix
                    cell = str(value)
                elif isinstance(value, float):
                    cell = f"{key}: {value:.{decimal}f}"
                elif isinstance(value, datetime):
                    timestamp = value.strftime("%Y-%m-%d %H:%M:%S")
                    cell = f"{key}: {timestamp}"
                else:
                    cell = f"{key}: {value}"

                lengths[3] = max(len(cell), lengths[3])

                # a fresh list per row, so later rows cannot alter earlier ones
                rows.append(
                    [
                        group_name if first_group_line else "",
                        count if first_group_line else "",
                        attribute_name if first_line else "",
                        cell,
                    ]
                )

                first_group_line = False
                first_line = False

    # every column is padded to its width and followed by one space
    rule = "-" * (sum(lengths) + len(lengths))

    table = [
        rule,
        "".join(f"{head.title():<{lengths[i]}} " for i, head in enumerate(header)),
        rule,
    ]

    table.extend(
        "".join(f"{row[x]: <{width}} " for x, width in enumerate(lengths))
        for row in rows
    )

    table.append(rule)

    return table
Loading

0 comments on commit ae32479

Please sign in to comment.