Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: interface for comparer and builder #198

Open
wants to merge 12 commits into
base: epic/157-create-an-evaluatorcomparer
Choose a base branch
from
10 changes: 5 additions & 5 deletions medmodels/medrecord/_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from medmodels.medrecord.types import (
AttributeInfo,
Attributes,
AttributeSummary,
EdgeIndex,
Group,
MedRecordAttribute,
Expand All @@ -25,7 +26,7 @@ def extract_attribute_summary(
schema: Optional[AttributesSchema] = None,
) -> Dict[
MedRecordAttribute,
Union[TemporalAttributeInfo, NumericAttributeInfo, StringAttributeInfo],
AttributeInfo,
]:
"""Extracts a summary from a node or edge attribute dictionary.

Expand All @@ -37,8 +38,7 @@ def extract_attribute_summary(
decimal (int): Decimal points to round the numeric values to. Defaults to 2.

Returns:
Dict[MedRecordAttribute, Union[TemporalAttributeInfo, NumericAttributeInfo,
StringAttributeInfo]: Summary of node or edge attributes.
Dict[MedRecordAttribute, AttributeInfo]: Summary of node or edge attributes.
"""
data = pl.DataFrame(data=[{"id": k, **v} for k, v in attribute_dictionary.items()])

Expand Down Expand Up @@ -165,12 +165,12 @@ def _extract_string_attribute_info(


def prettify_table(
data: Dict[Group, AttributeInfo], header: List[str], decimal: int
data: Dict[Group, AttributeSummary], header: List[str], decimal: int
LauraBoenchenLB marked this conversation as resolved.
Show resolved Hide resolved
) -> List[str]:
"""Takes a DataFrame and turns it into a list for displaying a pretty table.

Args:
data (Dict[Group, AttributeInfo]): Table info
data (Dict[Group, AttributeSummary]): Table info
stored in a dictionary.
header (List[str]): Header line consisting of column names for the table.
decimal (int): Decimal point to round the float values to.
Expand Down
18 changes: 10 additions & 8 deletions medmodels/medrecord/medrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from medmodels.medrecord.querying import EdgeOperand, EdgeQuery, NodeOperand, NodeQuery
from medmodels.medrecord.schema import Schema
from medmodels.medrecord.types import (
AttributeInfo,
Attributes,
AttributeSummary,
EdgeIndex,
EdgeIndexInputList,
EdgeInput,
Expand Down Expand Up @@ -77,20 +77,20 @@ def process_edges_dataframe(
class OverviewTable:
"""Class for the node/edge group overview table."""

data: Dict[Group, AttributeInfo]
data: Dict[Group, AttributeSummary]
group_header: str
decimal: int

def __init__(
self,
data: Dict[Group, AttributeInfo],
data: Dict[Group, AttributeSummary],
LauraBoenchenLB marked this conversation as resolved.
Show resolved Hide resolved
group_header: str,
decimal: int,
):
"""Initializes the OverviewTable class.

Args:
data (Dict[Group, AttributeInfo]): Dictionary containing attribute info for edges/nodes.
data (Dict[Group, AttributeSummary]): Dictionary containing attribute info for edges/nodes.
group_header (str): Header for group column, i.e. 'Group Nodes'.
decimal (int): Decimal point to round the float values to.
"""
Expand Down Expand Up @@ -1280,11 +1280,12 @@ def clone(self) -> MedRecord:

def _describe_group_nodes(
LauraBoenchenLB marked this conversation as resolved.
Show resolved Hide resolved
self,
) -> Dict[Group, AttributeInfo]:
) -> Dict[Group, AttributeSummary]:
"""Creates a summary of group nodes and their attributes.

Returns:
pl.DataFrame: Dataframe with all nodes in medrecord groups and their attributes.
Dict[Group, AttributeSummary]: Dictionary with all nodes in medrecord groups
and their attributes.
"""
nodes_info = {}
grouped_nodes = []
Expand Down Expand Up @@ -1316,11 +1317,12 @@ def _describe_group_nodes(

def _describe_group_edges(
self,
) -> Dict[Group, AttributeInfo]:
) -> Dict[Group, AttributeSummary]:
LauraBoenchenLB marked this conversation as resolved.
Show resolved Hide resolved
"""Creates a summary of edges connecting group nodes and the edge attributes.

Returns:
pl.DataFrame: DataFrame with an overview of edges connecting group nodes.
Dict[Group, AttributeSummary]: Dictionary with an overview of edges
connecting group nodes.
"""
edges_info = {}
grouped_edges = []
Expand Down
25 changes: 15 additions & 10 deletions medmodels/medrecord/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
Tuple[NodeIndex, NodeIndex, AttributesInput],
]


#: A type alias for input to a Polars DataFrame for nodes.
PolarsNodeDataFrameInput: TypeAlias = Tuple[pl.DataFrame, str]

Expand All @@ -87,6 +88,10 @@
#: A type alias for input to a Pandas DataFrame for edges.
PandasEdgeDataFrameInput: TypeAlias = Tuple[pd.DataFrame, str, str]

AttributeInfo: TypeAlias = Union[
"TemporalAttributeInfo", "NumericAttributeInfo", "StringAttributeInfo"
LauraBoenchenLB marked this conversation as resolved.
Show resolved Hide resolved
]

#: A type alias for input to a node.
NodeInput = Union[
NodeTuple,
Expand Down Expand Up @@ -115,16 +120,6 @@ class GroupInfo(TypedDict):
edges: List[EdgeIndex]


class AttributeInfo(TypedDict):
"""A dictionary containing info about nodes/edges and their attributes."""

count: int
attribute: Dict[
MedRecordAttribute,
Union[TemporalAttributeInfo, NumericAttributeInfo, StringAttributeInfo],
]


class TemporalAttributeInfo(TypedDict):
"""Dictionary for a temporal attribute and its metrics."""

Expand All @@ -146,6 +141,16 @@ class StringAttributeInfo(TypedDict):
values: str


class AttributeSummary(TypedDict):
"""A dictionary containing info about nodes/edges and their attributes."""

# presumably the number of nodes/edges being summarized -- TODO confirm
count: int
# Per-attribute info keyed by attribute; each value is typed as AttributeInfo.
attribute: Dict[
MedRecordAttribute,
AttributeInfo,
]


def is_medrecord_attribute(value: object) -> TypeIs[MedRecordAttribute]:
"""Check if a value is a MedRecord attribute.

Expand Down
83 changes: 83 additions & 0 deletions medmodels/statistic_evaluations/evaluate_compare/compare.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

from typing import Dict, List, Tuple, TypedDict

from medmodels.medrecord.types import (
AttributeInfo,
AttributeSummary,
Group,
MedRecordAttribute,
NodeIndex,
)
from medmodels.statistic_evaluations.evaluate_compare.evaluate import CohortEvaluator

class CohortSummary(TypedDict):
"""Dictionary for the cohort summary."""

# One AttributeSummary per group.
attribute_info: Dict[Group, AttributeSummary]
# One list of node indices per group.
# presumably the k most frequent concepts of that group -- TODO confirm
top_k_concepts: Dict[Group, List[NodeIndex]]

class DistanceSummary(TypedDict):
"""Dictionary for the Jensen-Shannon divergence and normalized distance between
distributions."""

# Jensen-Shannon divergence between the distributions.
js_divergence: float
# Normalized distance between the distributions.
distance: float

class ComparerSummary(TypedDict):
"""Dictionary for comparing results."""

# Hypothesis test results per attribute.
attribute_tests: Dict[MedRecordAttribute, List[TestSummary]]
# Hypothesis test results per group.
concepts_tests: Dict[Group, List[TestSummary]]
# Divergence/distance summary per group.
concepts_distance: Dict[Group, DistanceSummary]

class TestSummary(TypedDict):
"""Dictionary for hypothesis test results."""

# presumably the name of the statistical test that was run -- TODO confirm
test: str
# NOTE(review): this key is capitalized while every other key in this file is
# lower-case; renaming it would change the dictionary's interface, so confirm
# the intended spelling before the implementation lands.
Hypothesis: str
# presumably True when the null hypothesis is not rejected -- TODO confirm
not_reject: bool
# presumably the p-value of the test -- TODO confirm
p_value: float

class CohortComparer:
# Interface stub: every member is a static method whose body is `...`, so this
# block only fixes the signatures.
# Several methods return dictionaries keyed by `str`.
# presumably those keys are cohort names (CohortEvaluator.name) -- TODO confirm

# Takes a list of cohorts and one attribute; returns AttributeInfo per key.
@staticmethod
def compare_cohort_attribute(
cohorts: List[CohortEvaluator],
attribute: MedRecordAttribute,
) -> Dict[str, AttributeInfo]: ...
# NOTE(review): the first parameter is named `cohorts_attribute` here, while
# every other method taking a List[CohortEvaluator] calls it `cohorts`.
@staticmethod
def test_difference_attribute(
cohorts_attribute: List[CohortEvaluator],
attribute: MedRecordAttribute,
significance_level: float,
) -> List[TestSummary]: ...
# Takes a list of cohorts; returns one CohortSummary per key.
@staticmethod
def compare_cohorts(
cohorts: List[CohortEvaluator],
) -> Dict[str, CohortSummary]: ...
# NOTE(review): the result is keyed by `str`, whereas
# ComparerSummary.attribute_tests is keyed by MedRecordAttribute -- confirm
# whether the two are meant to hold the same mapping.
@staticmethod
def test_difference_cohort_attributes(
cohorts: List[CohortEvaluator],
significance_level: float,
) -> Dict[str, List[TestSummary]]: ...
# Takes exactly two cohorts (control and case) rather than a list; returns a
# float together with a per-attribute float mapping.
@staticmethod
def calculate_absolute_relative_difference(
control_group: CohortEvaluator,
case_group: CohortEvaluator,
) -> Tuple[float, Dict[MedRecordAttribute, float]]: ...
# Return type matches ComparerSummary.concepts_tests.
@staticmethod
def test_difference_top_k_concepts(
cohorts: List[CohortEvaluator],
top_k: int,
significance_level: float,
) -> Dict[Group, List[TestSummary]]: ...
# Return type matches ComparerSummary.concepts_distance.
@staticmethod
def calculate_distance_concepts(
cohorts: List[CohortEvaluator],
) -> Dict[Group, DistanceSummary]: ...
# Returns a pair: the first element has the return type of `compare_cohorts`,
# the second is a ComparerSummary.
@staticmethod
def full_comparison(
cohorts: List[CohortEvaluator],
top_k: int,
significance_level: float,
) -> Tuple[Dict[str, CohortSummary], ComparerSummary]: ...
44 changes: 44 additions & 0 deletions medmodels/statistic_evaluations/evaluate_compare/evaluate.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Union

from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.querying import NodeQuery
from medmodels.medrecord.schema import AttributeType
from medmodels.medrecord.types import (
AttributeSummary,
Group,
GroupInputList,
MedRecordAttribute,
NodeIndex,
)

class CohortEvaluator:
# Interface stub: declares the instance attributes and method signatures only;
# every method body is `...`.
medrecord: MedRecord
name: str
# NOTE(review): annotated as Group here, but __init__ accepts
# Union[Group, NodeQuery] for `cohort_group` -- confirm how a NodeQuery
# argument ends up stored as a Group.
cohort_group: Group
time_attribute: MedRecordAttribute
attributes: Optional[Dict[str, MedRecordAttribute]]
concepts_groups: Optional[GroupInputList]
attribute_summary: Dict[Group, AttributeSummary]
attribute_types: Dict[MedRecordAttribute, AttributeType]

# Defaults visible in the signature: cohort_group="patients",
# time_attribute="time", attributes=None, concepts_groups=None.
def __init__(
self,
medrecord: MedRecord,
name: str,
cohort_group: Union[Group, NodeQuery] = "patients",
time_attribute: MedRecordAttribute = "time",
attributes: Optional[Dict[str, MedRecordAttribute]] = None,
concepts_groups: Optional[GroupInputList] = None,
) -> None: ...
# Returns a list of (node index, int) pairs.
# presumably the int is an occurrence count for that concept -- TODO confirm
def get_concept_counts(
self,
) -> List[Tuple[NodeIndex, int]]: ...
# Returns a list of node indices for the given `top_k`.
def get_top_k_concepts(
self,
top_k: int,
) -> List[NodeIndex]: ...
MarIniOnz marked this conversation as resolved.
Show resolved Hide resolved
# Return type is identical to the `attribute_summary` attribute declared above.
def get_attribute_summary(
self,
) -> Dict[Group, AttributeSummary]: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import polars as pl

from medmodels.medrecord import MedRecord
from medmodels.medrecord.querying import NodeQuery
from medmodels.medrecord.schema import AttributeType
from medmodels.medrecord.types import (
NumericAttributeInfo,
StringAttributeInfo,
TemporalAttributeInfo,
)

# Stub: maps a polars Series of attribute values to an AttributeType.
def determine_attribute_type(attribute_values: pl.Series) -> AttributeType: ...
# Stub: takes a MedRecord and a NodeQuery and returns a NumericAttributeInfo.
# NOTE(review): the name says "continuous" while the return type says "Numeric";
# confirm the two terms are meant to be interchangeable.
def get_continuous_attribute_statistics(
medrecord: MedRecord, attribute_query: NodeQuery
) -> NumericAttributeInfo: ...
# Stub: takes a MedRecord and a NodeQuery and returns a TemporalAttributeInfo.
def get_temporal_attribute_statistics(
medrecord: MedRecord, attribute_query: NodeQuery
) -> TemporalAttributeInfo: ...
# Stub: takes a MedRecord and a NodeQuery and returns a StringAttributeInfo.
# NOTE(review): the name says "categorical" while the return type says "String";
# confirm the two terms are meant to be interchangeable.
def get_categorical_attribute_statistics(
medrecord: MedRecord, attribute_query: NodeQuery
) -> StringAttributeInfo: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List, Tuple

from numpy.typing import ArrayLike

from medmodels.medrecord.schema import AttributeType
from medmodels.statistic_evaluations.evaluate_compare.compare import TestSummary

# Stub: takes one array-like sample and returns a bool.
# presumably True means the sample is judged normally distributed -- TODO confirm
def normal_distribution_test(sample: ArrayLike) -> bool: ...
# Stub: takes the samples, the attribute's AttributeType and `alpha`, and returns
# a single TestSummary.
# presumably `alpha` is the significance level -- TODO confirm
def decide_hypothesis_test(
samples: List[ArrayLike], attribute_type: AttributeType, alpha: float
) -> TestSummary: ...
# Stub: takes a list of array-like samples and `alpha`; returns a TestSummary.
def two_tailed_t_test(samples: List[ArrayLike], alpha: float) -> TestSummary: ...
# Stub: takes a list of array-like samples and `alpha`; returns a TestSummary.
def mann_whitney_u_test(samples: List[ArrayLike], alpha: float) -> TestSummary: ...
# Stub: takes a list of array-like samples and `alpha`; returns a TestSummary.
def analysis_of_variance(samples: List[ArrayLike], alpha: float) -> TestSummary: ...
# Stub: takes a list of array-like samples and `alpha`; returns a TestSummary.
# NOTE(review): "independece" in the function name looks like a typo for
# "independence". Renaming changes the public interface, so fix it before the
# implementation and its callers are merged.
def chi_square_independece_test(
samples: List[ArrayLike], alpha: float
) -> TestSummary: ...
# Stub: takes a list of array-like samples and returns a (str, float) pair.
# presumably the str names the effect-size measure and the float is its value
# -- TODO confirm
def measure_effect_size(samples: List[ArrayLike]) -> Tuple[str, float]: ...
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"pandas>=2.2.2",
"polars[pandas]>=1.6.0",
"scikit-learn>=1.5.0",
"scipy>=1.9.0",
]

[project.optional-dependencies]
Expand Down