Skip to content

Commit

Permalink
refactor: implement new querying interface (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
JabobKrauskopf committed Oct 11, 2024
1 parent 10ebca6 commit 8000ce1
Show file tree
Hide file tree
Showing 12 changed files with 577 additions and 3,099 deletions.
14 changes: 7 additions & 7 deletions medmodels/medrecord/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
)
from medmodels.medrecord.medrecord import (
EdgeIndex,
EdgeOperation,
EdgeQuery,
MedRecord,
NodeIndex,
NodeOperation,
NodeQuery,
)
from medmodels.medrecord.querying import edge, node
from medmodels.medrecord.querying import EdgeOperand, NodeOperand
from medmodels.medrecord.schema import AttributeType, GroupSchema, Schema

__all__ = [
Expand All @@ -33,10 +33,10 @@
"AttributeType",
"Schema",
"GroupSchema",
"node",
"edge",
"NodeIndex",
"EdgeIndex",
"NodeOperation",
"EdgeOperation",
"EdgeQuery",
"NodeQuery",
"NodeOperand",
"EdgeOperand",
]
86 changes: 43 additions & 43 deletions medmodels/medrecord/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Dict, Tuple, Union, overload

from medmodels.medrecord.querying import EdgeOperation, NodeOperation
from medmodels.medrecord.querying import EdgeQuery, NodeQuery
from medmodels.medrecord.types import (
Attributes,
AttributesInput,
Expand Down Expand Up @@ -48,10 +48,10 @@ def __getitem__(
self,
key: Union[
NodeIndexInputList,
NodeOperation,
NodeQuery,
slice,
Tuple[
Union[NodeIndexInputList, NodeOperation, slice],
Union[NodeIndexInputList, NodeQuery, slice],
Union[MedRecordAttributeInputList, slice],
],
],
Expand All @@ -60,18 +60,18 @@ def __getitem__(
@overload
def __getitem__(
self,
key: Tuple[Union[NodeIndexInputList, NodeOperation, slice], MedRecordAttribute],
key: Tuple[Union[NodeIndexInputList, NodeQuery, slice], MedRecordAttribute],
) -> Dict[NodeIndex, MedRecordValue]: ...

def __getitem__(
self,
key: Union[
NodeIndex,
NodeIndexInputList,
NodeOperation,
NodeQuery,
slice,
Tuple[
Union[NodeIndex, NodeIndexInputList, NodeOperation, slice],
Union[NodeIndex, NodeIndexInputList, NodeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
],
Expand All @@ -87,7 +87,7 @@ def __getitem__(
if isinstance(key, list):
return self._medrecord._medrecord.node(key)

if isinstance(key, NodeOperation):
if isinstance(key, NodeQuery):
return self._medrecord._medrecord.node(self._medrecord.select_nodes(key))

if isinstance(key, slice):
Expand All @@ -112,7 +112,7 @@ def __getitem__(

return {x: attributes[x][attribute_selection] for x in attributes.keys()}

if isinstance(index_selection, NodeOperation) and is_medrecord_attribute(
if isinstance(index_selection, NodeQuery) and is_medrecord_attribute(
attribute_selection
):
attributes = self._medrecord._medrecord.node(
Expand Down Expand Up @@ -151,7 +151,7 @@ def __getitem__(
for x in attributes.keys()
}

if isinstance(index_selection, NodeOperation) and isinstance(
if isinstance(index_selection, NodeQuery) and isinstance(
attribute_selection, list
):
attributes = self._medrecord._medrecord.node(
Expand Down Expand Up @@ -198,7 +198,7 @@ def __getitem__(

return self._medrecord._medrecord.node(index_selection)

if isinstance(index_selection, NodeOperation) and isinstance(
if isinstance(index_selection, NodeQuery) and isinstance(
attribute_selection, slice
):
if (
Expand Down Expand Up @@ -230,15 +230,15 @@ def __getitem__(
@overload
def __setitem__(
self,
key: Union[NodeIndex, NodeIndexInputList, NodeOperation, slice],
key: Union[NodeIndex, NodeIndexInputList, NodeQuery, slice],
value: AttributesInput,
) -> None: ...

@overload
def __setitem__(
self,
key: Tuple[
Union[NodeIndex, NodeIndexInputList, NodeOperation, slice],
Union[NodeIndex, NodeIndexInputList, NodeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
value: MedRecordValue,
Expand All @@ -249,10 +249,10 @@ def __setitem__(
key: Union[
NodeIndex,
NodeIndexInputList,
NodeOperation,
NodeQuery,
slice,
Tuple[
Union[NodeIndex, NodeIndexInputList, NodeOperation, slice],
Union[NodeIndex, NodeIndexInputList, NodeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
],
Expand All @@ -270,7 +270,7 @@ def __setitem__(

return self._medrecord._medrecord.replace_node_attributes(key, value)

if isinstance(key, NodeOperation):
if isinstance(key, NodeQuery):
if not is_attributes(value):
raise ValueError("Invalid value type. Expected Attributes")

Expand Down Expand Up @@ -311,7 +311,7 @@ def __setitem__(
index_selection, attribute_selection, value
)

if isinstance(index_selection, NodeOperation) and is_medrecord_attribute(
if isinstance(index_selection, NodeQuery) and is_medrecord_attribute(
attribute_selection
):
if not is_medrecord_value(value):
Expand Down Expand Up @@ -364,7 +364,7 @@ def __setitem__(

return

if isinstance(index_selection, NodeOperation) and isinstance(
if isinstance(index_selection, NodeQuery) and isinstance(
attribute_selection, list
):
if not is_medrecord_value(value):
Expand Down Expand Up @@ -440,7 +440,7 @@ def __setitem__(

return

if isinstance(index_selection, NodeOperation) and isinstance(
if isinstance(index_selection, NodeQuery) and isinstance(
attribute_selection, slice
):
if (
Expand Down Expand Up @@ -494,7 +494,7 @@ def __setitem__(
def __delitem__(
self,
key: Tuple[
Union[NodeIndex, NodeIndexInputList, NodeOperation, slice],
Union[NodeIndex, NodeIndexInputList, NodeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
) -> None:
Expand All @@ -514,7 +514,7 @@ def __delitem__(
index_selection, attribute_selection
)

if isinstance(index_selection, NodeOperation) and is_medrecord_attribute(
if isinstance(index_selection, NodeQuery) and is_medrecord_attribute(
attribute_selection
):
return self._medrecord._medrecord.remove_node_attribute(
Expand Down Expand Up @@ -553,7 +553,7 @@ def __delitem__(

return

if isinstance(index_selection, NodeOperation) and isinstance(
if isinstance(index_selection, NodeQuery) and isinstance(
attribute_selection, list
):
for attribute in attribute_selection:
Expand Down Expand Up @@ -602,7 +602,7 @@ def __delitem__(
index_selection, {}
)

if isinstance(index_selection, NodeOperation) and isinstance(
if isinstance(index_selection, NodeQuery) and isinstance(
attribute_selection, slice
):
if (
Expand Down Expand Up @@ -658,10 +658,10 @@ def __getitem__(
self,
key: Union[
EdgeIndexInputList,
EdgeOperation,
EdgeQuery,
slice,
Tuple[
Union[EdgeIndexInputList, EdgeOperation, slice],
Union[EdgeIndexInputList, EdgeQuery, slice],
Union[MedRecordAttributeInputList, slice],
],
],
Expand All @@ -670,18 +670,18 @@ def __getitem__(
@overload
def __getitem__(
self,
key: Tuple[Union[EdgeIndexInputList, EdgeOperation, slice], MedRecordAttribute],
key: Tuple[Union[EdgeIndexInputList, EdgeQuery, slice], MedRecordAttribute],
) -> Dict[EdgeIndex, MedRecordValue]: ...

def __getitem__(
self,
key: Union[
EdgeIndex,
EdgeIndexInputList,
EdgeOperation,
EdgeQuery,
slice,
Tuple[
Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice],
Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
],
Expand All @@ -697,7 +697,7 @@ def __getitem__(
if isinstance(key, list):
return self._medrecord._medrecord.edge(key)

if isinstance(key, EdgeOperation):
if isinstance(key, EdgeQuery):
return self._medrecord._medrecord.edge(self._medrecord.select_edges(key))

if isinstance(key, slice):
Expand All @@ -722,7 +722,7 @@ def __getitem__(

return {x: attributes[x][attribute_selection] for x in attributes.keys()}

if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute(
if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute(
attribute_selection
):
attributes = self._medrecord._medrecord.edge(
Expand Down Expand Up @@ -761,7 +761,7 @@ def __getitem__(
for x in attributes.keys()
}

if isinstance(index_selection, EdgeOperation) and isinstance(
if isinstance(index_selection, EdgeQuery) and isinstance(
attribute_selection, list
):
attributes = self._medrecord._medrecord.edge(
Expand Down Expand Up @@ -808,7 +808,7 @@ def __getitem__(

return self._medrecord._medrecord.edge(index_selection)

if isinstance(index_selection, EdgeOperation) and isinstance(
if isinstance(index_selection, EdgeQuery) and isinstance(
attribute_selection, slice
):
if (
Expand Down Expand Up @@ -840,15 +840,15 @@ def __getitem__(
@overload
def __setitem__(
self,
key: Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice],
key: Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice],
value: AttributesInput,
) -> None: ...

@overload
def __setitem__(
self,
key: Tuple[
Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice],
Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
value: MedRecordValue,
Expand All @@ -859,10 +859,10 @@ def __setitem__(
key: Union[
EdgeIndex,
EdgeIndexInputList,
EdgeOperation,
EdgeQuery,
slice,
Tuple[
Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice],
Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
],
Expand All @@ -880,7 +880,7 @@ def __setitem__(

return self._medrecord._medrecord.replace_edge_attributes(key, value)

if isinstance(key, EdgeOperation):
if isinstance(key, EdgeQuery):
if not is_attributes(value):
raise ValueError("Invalid value type. Expected Attributes")

Expand Down Expand Up @@ -921,7 +921,7 @@ def __setitem__(
index_selection, attribute_selection, value
)

if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute(
if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute(
attribute_selection
):
if not is_medrecord_value(value):
Expand Down Expand Up @@ -974,7 +974,7 @@ def __setitem__(

return

if isinstance(index_selection, EdgeOperation) and isinstance(
if isinstance(index_selection, EdgeQuery) and isinstance(
attribute_selection, list
):
if not is_medrecord_value(value):
Expand Down Expand Up @@ -1048,7 +1048,7 @@ def __setitem__(

return

if isinstance(index_selection, EdgeOperation) and isinstance(
if isinstance(index_selection, EdgeQuery) and isinstance(
attribute_selection, slice
):
if (
Expand Down Expand Up @@ -1102,7 +1102,7 @@ def __setitem__(
def __delitem__(
self,
key: Tuple[
Union[EdgeIndex, EdgeIndexInputList, EdgeOperation, slice],
Union[EdgeIndex, EdgeIndexInputList, EdgeQuery, slice],
Union[MedRecordAttribute, MedRecordAttributeInputList, slice],
],
) -> None:
Expand All @@ -1122,7 +1122,7 @@ def __delitem__(
index_selection, attribute_selection
)

if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute(
if isinstance(index_selection, EdgeQuery) and is_medrecord_attribute(
attribute_selection
):
return self._medrecord._medrecord.remove_edge_attribute(
Expand Down Expand Up @@ -1161,7 +1161,7 @@ def __delitem__(

return

if isinstance(index_selection, EdgeOperation) and isinstance(
if isinstance(index_selection, EdgeQuery) and isinstance(
attribute_selection, list
):
for attribute in attribute_selection:
Expand Down Expand Up @@ -1210,7 +1210,7 @@ def __delitem__(
index_selection, {}
)

if isinstance(index_selection, EdgeOperation) and isinstance(
if isinstance(index_selection, EdgeQuery) and isinstance(
attribute_selection, slice
):
if (
Expand Down
Loading

0 comments on commit 8000ce1

Please sign in to comment.