From d19dc011943f05325c49f070ed86e4aeca6719a6 Mon Sep 17 00:00:00 2001 From: FloLimebit Date: Mon, 26 Aug 2024 10:27:12 +0200 Subject: [PATCH 1/7] intermediate state --- docs/api/treatment_effect.md | 8 +- docs/conf.py | 2 + docs/developer_guide/example_docstrings.py | 292 +++++++++++++++++++++ pyproject.toml | 14 + 4 files changed, 312 insertions(+), 4 deletions(-) create mode 100644 docs/developer_guide/example_docstrings.py diff --git a/docs/api/treatment_effect.md b/docs/api/treatment_effect.md index 56af0814..a8c47818 100644 --- a/docs/api/treatment_effect.md +++ b/docs/api/treatment_effect.md @@ -7,9 +7,9 @@ The `TreatmentEffect` module is designed for analyzing treatment effects within :toctree: _autosummary :recursive: - medmodels.treatment_effect.treatment_effect.TreatmentEffect - medmodels.treatment_effect.builder.TreatmentEffectBuilder - medmodels.treatment_effect.estimate.Estimate - medmodels.treatment_effect.report.Report + medmodels.treatment_effect.treatment_effect + medmodels.treatment_effect.builder + medmodels.treatment_effect.estimate + medmodels.treatment_effect.report medmodels.treatment_effect.temporal_analysis ``` diff --git a/docs/conf.py b/docs/conf.py index 89ee37a0..331789f6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,6 +29,7 @@ "sphinx_togglebutton", "sphinx_multiversion", "sphinx.ext.extlinks", + "sphinx.ext.coverage", ] exclude_patterns = ["_build"] @@ -69,6 +70,7 @@ "private-members": False, "inherited-members": True, "show-inheritance": True, + "ignore-module-all": False, } autosummary_generate = True diff --git a/docs/developer_guide/example_docstrings.py b/docs/developer_guide/example_docstrings.py new file mode 100644 index 00000000..b95a21a2 --- /dev/null +++ b/docs/developer_guide/example_docstrings.py @@ -0,0 +1,292 @@ +"""Example Google style docstrings. + +This module demonstrates documentation as specified by the `Google Python +Style Guide`_. Docstrings may extend over multiple lines. Sections are created +with a section header and a colon followed by a block of indented text. + +Example: + Examples can be given using either the ``Example`` or ``Examples`` + sections. Sections support any reStructuredText formatting, including + literal blocks:: + + $ python example_google.py + +Section breaks are created by resuming unindented text. Section breaks +are also implicitly created anytime a new section starts. + +Attributes: + module_level_variable1 (int): Module level variables may be documented in + either the ``Attributes`` section of the module docstring, or in an + inline docstring immediately following the variable. + + Either form is acceptable, but the two should not be mixed. Choose + one convention to document module level variables and be consistent + with it. + +Todo: + * For module TODOs + * You have to also use ``sphinx.ext.todo`` extension + +.. _Google Python Style Guide: + https://google.github.io/styleguide/pyguide.html + +""" + +module_level_variable1 = 12345 + +module_level_variable2 = 98765 +"""int: Module level variable documented inline. + +The docstring may span multiple lines. The type may optionally be specified +on the first line, separated by a colon. +""" + + +def function(param1: int, param2: int) -> bool: + """Example function with PEP 484 type annotations. + + Args: + param1: The first parameter. + param2: The second parameter. + + Returns: + The return value. True for success, False otherwise. + + """ + return param1 == param2 + + +def module_level_function(param1: int, param2: int, *args, **kwargs): + """This is an example of a module level function. + + Function parameters should be documented in the ``Args`` section. The name + of each parameter is required. The type and description of each parameter + is optional, but should be included if not obvious. + + If ``*args`` or ``**kwargs`` are accepted, + they should be listed as ``*args`` and ``**kwargs``. + + The format for a parameter is:: + + name (type): description + The description may span multiple lines. Following + lines should be indented. The "(type)" is optional. + + Multiple paragraphs are supported in parameter + descriptions. + + Args: + param1 (int): The first parameter. + param2 (:obj:`str`, optional): The second parameter. Defaults to None. + Second line of description should be indented. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + bool: True if successful, False otherwise. + + The return type is optional and may be specified at the beginning of + the ``Returns`` section followed by a colon. + + The ``Returns`` section may span multiple lines and paragraphs. + Following lines should be indented to match the first line. + + The ``Returns`` section supports any reStructuredText formatting, + including literal blocks:: + + {"param1": param1, "param2": param2} + + Raises: + AttributeError: The ``Raises`` section is a list of all exceptions + that are relevant to the interface. + ValueError: If `param2` is equal to `param1`. + + """ + if param1 == param2: + raise ValueError("param1 may not be equal to param2") + return True + + +def example_generator(n): + """Generators have a ``Yields`` section instead of a ``Returns`` section. + + Args: + n (int): The upper limit of the range to generate, from 0 to `n` - 1. + + Yields: + int: The next number in the range of 0 to `n` - 1. + + Examples: + Examples should be written in doctest format, and should illustrate how + to use the function. + + >>> print([i for i in example_generator(4)]) + [0, 1, 2, 3] + + """ + yield from range(n) + + +class ExampleError(Exception): + """Exceptions are documented in the same way as classes. + + The __init__ method may be documented in either the class level + docstring, or as a docstring on the __init__ method itself. + + Either form is acceptable, but the two should not be mixed. Choose one + convention to document the __init__ method and be consistent with it. + + Note: + Do not include the `self` parameter in the ``Args`` section. + + Args: + msg (str): Human readable string describing the exception. + code (:obj:`int`, optional): Error code. + + Attributes: + msg (str): Human readable string describing the exception. + code (int): Exception error code. + + """ + + def __init__(self, msg, code): + self.msg = msg + self.code = code + + +class ExampleClass: + """The summary line for a class docstring should fit on one line. + + If the class has public attributes, they may be documented here + in an ``Attributes`` section and follow the same formatting as a + function's ``Args`` section. Alternatively, attributes may be documented + inline with the attribute's declaration (see __init__ method below). + + Properties created with the ``@property`` decorator should be documented + in the property's getter method. + + Attributes: + attr1 (str): Description of `attr1`. + attr2 (:obj:`int`, optional): Description of `attr2`. + + """ + + def __init__(self, param1, param2, param3): + """Example of docstring on the __init__ method. + + The __init__ method may be documented in either the class level + docstring, or as a docstring on the __init__ method itself. + + Either form is acceptable, but the two should not be mixed. Choose one + convention to document the __init__ method and be consistent with it. + + Note: + Do not include the `self` parameter in the ``Args`` section. + + Args: + param1 (str): Description of `param1`. + param2 (:obj:`int`, optional): Description of `param2`. Multiple + lines are supported. + param3 (list(str)): Description of `param3`. + + """ + self.attr1 = param1 + self.attr2 = param2 + self.attr3 = param3 #: Doc comment *inline* with attribute + + #: list(str): Doc comment *before* attribute, with type specified + self.attr4 = ["attr4"] + + self.attr5 = None + """str: Docstring *after* attribute, with type specified.""" + + @property + def readonly_property(self): + """str: Properties should be documented in their getter method.""" + return "readonly_property" + + @property + def readwrite_property(self): + """list(str): Properties with both a getter and setter + should only be documented in their getter method. + + If the setter method contains notable behavior, it should be + mentioned here. + """ + return ["readwrite_property"] + + @readwrite_property.setter + def readwrite_property(self, value): + value + + def example_method(self, param1, param2): + """Class methods are similar to regular functions. + + Note: + Do not include the `self` parameter in the ``Args`` section. + + Args: + param1: The first parameter. + param2: The second parameter. + + Returns: + True if successful, False otherwise. + + """ + return True + + def __special__(self): + """By default special members with docstrings are not included. + + Special members are any methods or attributes that start with and + end with a double underscore. Any special member with a docstring + will be included in the output, if + ``napoleon_include_special_with_doc`` is set to True. + + This behavior can be enabled by changing the following setting in + Sphinx's conf.py:: + + napoleon_include_special_with_doc = True + + """ + pass + + def __special_without_docstring__(self): + pass + + def _private(self): + """By default private members are not included. + + Private members are any methods or attributes that start with an + underscore and are *not* special. By default they are not included + in the output. + + This behavior can be changed such that private members *are* included + by changing the following setting in Sphinx's conf.py:: + + napoleon_include_private_with_doc = True + + """ + pass + + def _private_without_docstring(self): + pass + + +class ExamplePEP526Class: + """The summary line for a class docstring should fit on one line. + + If the class has public attributes, they may be documented here + in an ``Attributes`` section and follow the same formatting as a + function's ``Args`` section. If ``napoleon_attr_annotations`` + is True, types can be specified in the class body using ``PEP 526`` + annotations. + + Attributes: + attr1: Description of `attr1`. + attr2: Description of `attr2`. + + """ + + attr1: str + attr2: int diff --git a/pyproject.toml b/pyproject.toml index 116b3eb6..fde95ac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,20 @@ exclude = [ ] line-length = 88 +[tool.ruff.lint] +select = ["PLR"] +preview = true + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 88 + [tool.pyright] typeCheckingMode = "strict" reportPrivateUsage = false From 2b13a4a4d8d123fff466e6f567a2967916d43358 Mon Sep 17 00:00:00 2001 From: FloLimebit Date: Tue, 27 Aug 2024 16:11:12 +0200 Subject: [PATCH 2/7] #161 new linting rules and auto-fix on full codebase --- docs/serve_docs.py | 16 ++-- medmodels/medrecord/__init__.py | 26 +++--- medmodels/medrecord/datatype.py | 23 +++-- medmodels/medrecord/indexers.py | 48 +++++----- medmodels/medrecord/medrecord.py | 30 +++--- medmodels/medrecord/querying.py | 9 +- medmodels/medrecord/schema.py | 92 +++++++------------ .../treatment_effect/matching/matching.py | 4 +- pyproject.toml | 52 ++++++++++- 9 files changed, 158 insertions(+), 142 deletions(-) diff --git a/docs/serve_docs.py b/docs/serve_docs.py index c2c2b7ce..b5952a02 100644 --- a/docs/serve_docs.py +++ b/docs/serve_docs.py @@ -1,15 +1,17 @@ #!/usr/bin/env python3 +import logging import os import subprocess from pathlib import Path from livereload import Server, shell +logging.basicConfig(level=logging.INFO) -def setup_live_docs_server(): - """ - Set up and run a live documentation server. + +def setup_live_docs_server() -> None: + """Set up and run a live documentation server. This function initializes a server that automatically rebuilds and refreshes the documentation when source files are modified. @@ -25,7 +27,7 @@ def setup_live_docs_server(): rebuild_cmd = shell("make html", cwd=str(script_path)) # Initially build the documentation - print("Building initial documentation...") + logging.info("Building initial documentation...") subprocess.run(["make", "html"], cwd=str(script_path), check=True) # File patterns to watch for changes @@ -48,7 +50,7 @@ def setup_live_docs_server(): if __name__ == "__main__": - print("Starting live documentation server...") - print("Access the docs at http://localhost:5500") - print("Press Ctrl+C to stop the server.") + logging.info("Starting live documentation server...") + logging.info("Access the docs at http://localhost:5500") + logging.info("Press Ctrl+C to stop the server.") setup_live_docs_server() diff --git a/medmodels/medrecord/__init__.py b/medmodels/medrecord/__init__.py index 3a9f3f6f..01913f8f 100644 --- a/medmodels/medrecord/__init__.py +++ b/medmodels/medrecord/__init__.py @@ -20,23 +20,23 @@ from medmodels.medrecord.schema import AttributeType, GroupSchema, Schema __all__ = [ - "MedRecord", - "String", - "Int", - "Float", + "Any", + "AttributeType", "Bool", "DateTime", + "EdgeIndex", + "EdgeOperation", + "Float", + "GroupSchema", + "Int", + "MedRecord", + "NodeIndex", + "NodeOperation", "Null", - "Any", - "Union", "Option", - "AttributeType", "Schema", - "GroupSchema", - "node", + "String", + "Union", "edge", - "NodeIndex", - "EdgeIndex", - "NodeOperation", - "EdgeOperation", + "node", ] diff --git a/medmodels/medrecord/datatype.py b/medmodels/medrecord/datatype.py index 2dcfdea3..28241bd2 100644 --- a/medmodels/medrecord/datatype.py +++ b/medmodels/medrecord/datatype.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from medmodels._medmodels import ( PyAny, @@ -36,7 +36,7 @@ ] -class DataType(metaclass=ABCMeta): +class DataType(ABC): @abstractmethod def _inner(self) -> PyDataType: ... @@ -53,25 +53,24 @@ def __eq__(self, value: object) -> bool: ... def _from_pydatatype(datatype: PyDataType) -> DataType: if isinstance(datatype, PyString): return String() - elif isinstance(datatype, PyInt): + if isinstance(datatype, PyInt): return Int() - elif isinstance(datatype, PyFloat): + if isinstance(datatype, PyFloat): return Float() - elif isinstance(datatype, PyBool): + if isinstance(datatype, PyBool): return Bool() - elif isinstance(datatype, PyDateTime): + if isinstance(datatype, PyDateTime): return DateTime() - elif isinstance(datatype, PyNull): + if isinstance(datatype, PyNull): return Null() - elif isinstance(datatype, PyAny): + if isinstance(datatype, PyAny): return Any() - elif isinstance(datatype, PyUnion): + if isinstance(datatype, PyUnion): return Union( DataType._from_pydatatype(datatype.dtype1), DataType._from_pydatatype(datatype.dtype2), ) - else: - return Option(DataType._from_pydatatype(datatype.dtype)) + return Option(DataType._from_pydatatype(datatype.dtype)) class String(DataType): @@ -213,7 +212,7 @@ class Union(DataType): def __init__(self, *dtypes: DataType) -> None: if len(dtypes) < 2: raise ValueError("Union must have at least two arguments") - elif len(dtypes) == 2: + if len(dtypes) == 2: self._union = PyUnion(dtypes[0]._inner(), dtypes[1]._inner()) else: self._union = PyUnion(dtypes[0]._inner(), Union(*dtypes[1:])._inner()) diff --git a/medmodels/medrecord/indexers.py b/medmodels/medrecord/indexers.py index b76404e1..5fbe808c 100644 --- a/medmodels/medrecord/indexers.py +++ b/medmodels/medrecord/indexers.py @@ -351,7 +351,7 @@ def __setitem__( [index_selection], attribute, value ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): if not is_medrecord_value(value): @@ -362,7 +362,7 @@ def __setitem__( index_selection, attribute, value ) - return + return None if isinstance(index_selection, NodeOperation) and isinstance( attribute_selection, list @@ -375,7 +375,7 @@ def __setitem__( self._medrecord.select_nodes(index_selection), attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -393,7 +393,7 @@ def __setitem__( self._medrecord.nodes, attribute, value ) - return + return None if is_node_index(index_selection) and isinstance(attribute_selection, slice): if ( @@ -417,7 +417,7 @@ def __setitem__( value, ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, slice): if ( @@ -438,7 +438,7 @@ def __setitem__( [node], attribute, value ) - return + return None if isinstance(index_selection, NodeOperation) and isinstance( attribute_selection, slice @@ -463,7 +463,7 @@ def __setitem__( [node], attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance( attribute_selection, slice @@ -489,7 +489,7 @@ def __setitem__( [node], attribute, value ) - return + return None def __delitem__( self, @@ -543,7 +543,7 @@ def __delitem__( [index_selection], attribute ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): for attribute in attribute_selection: @@ -551,7 +551,7 @@ def __delitem__( index_selection, attribute ) - return + return None if isinstance(index_selection, NodeOperation) and isinstance( attribute_selection, list @@ -561,7 +561,7 @@ def __delitem__( self._medrecord.select_nodes(index_selection), attribute ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -576,7 +576,7 @@ def __delitem__( self._medrecord.nodes, attribute ) - return + return None if is_node_index(index_selection) and isinstance(attribute_selection, slice): if ( @@ -961,7 +961,7 @@ def __setitem__( [index_selection], attribute, value ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): if not is_medrecord_value(value): @@ -972,7 +972,7 @@ def __setitem__( index_selection, attribute, value ) - return + return None if isinstance(index_selection, EdgeOperation) and isinstance( attribute_selection, list @@ -985,7 +985,7 @@ def __setitem__( self._medrecord.select_edges(index_selection), attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -1003,7 +1003,7 @@ def __setitem__( self._medrecord.edges, attribute, value ) - return + return None if is_edge_index(index_selection) and isinstance(attribute_selection, slice): if ( @@ -1025,7 +1025,7 @@ def __setitem__( [index_selection], attribute, value ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, slice): if ( @@ -1046,7 +1046,7 @@ def __setitem__( [edge], attribute, value ) - return + return None if isinstance(index_selection, EdgeOperation) and isinstance( attribute_selection, slice @@ -1071,7 +1071,7 @@ def __setitem__( [edge], attribute, value ) - return + return None if isinstance(index_selection, slice) and isinstance( attribute_selection, slice @@ -1097,7 +1097,7 @@ def __setitem__( [edge], attribute, value ) - return + return None def __delitem__( self, @@ -1151,7 +1151,7 @@ def __delitem__( [index_selection], attribute ) - return + return None if isinstance(index_selection, list) and isinstance(attribute_selection, list): for attribute in attribute_selection: @@ -1159,7 +1159,7 @@ def __delitem__( index_selection, attribute ) - return + return None if isinstance(index_selection, EdgeOperation) and isinstance( attribute_selection, list @@ -1169,7 +1169,7 @@ def __delitem__( self._medrecord.select_edges(index_selection), attribute ) - return + return None if isinstance(index_selection, slice) and isinstance(attribute_selection, list): if ( @@ -1184,7 +1184,7 @@ def __delitem__( self._medrecord.edges, attribute ) - return + return None if is_edge_index(index_selection) and isinstance(attribute_selection, slice): if ( diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index 8446eae8..1820366b 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -526,11 +526,10 @@ def edges_connecting( (source_node if isinstance(source_node, list) else [source_node]), (target_node if isinstance(target_node, list) else [target_node]), ) - else: - return self._medrecord.edges_connecting_undirected( - (source_node if isinstance(source_node, list) else [source_node]), - (target_node if isinstance(target_node, list) else [target_node]), - ) + return self._medrecord.edges_connecting_undirected( + (source_node if isinstance(source_node, list) else [source_node]), + (target_node if isinstance(target_node, list) else [target_node]), + ) def add_node(self, node: NodeIndex, attributes: AttributesInput) -> None: """Adds a node with specified attributes to the MedRecord instance. @@ -609,12 +608,11 @@ def add_nodes( nodes ): return self.add_nodes_pandas(nodes) - elif is_polars_node_dataframe_input( + if is_polars_node_dataframe_input(nodes) or is_polars_node_dataframe_input_list( nodes - ) or is_polars_node_dataframe_input_list(nodes): + ): return self.add_nodes_polars(nodes) - else: - return self._medrecord.add_nodes(nodes) + return self._medrecord.add_nodes(nodes) def add_nodes_pandas( self, nodes: Union[PandasNodeDataFrameInput, List[PandasNodeDataFrameInput]] @@ -740,12 +738,11 @@ def add_edges( edges ): return self.add_edges_pandas(edges) - elif is_polars_edge_dataframe_input( + if is_polars_edge_dataframe_input(edges) or is_polars_edge_dataframe_input_list( edges - ) or is_polars_edge_dataframe_input_list(edges): + ): return self.add_edges_polars(edges) - else: - return self._medrecord.add_edges(edges) + return self._medrecord.add_edges(edges) def add_edges_pandas( self, edges: Union[PandasEdgeDataFrameInput, List[PandasEdgeDataFrameInput]] @@ -827,16 +824,15 @@ def add_group( nodes if isinstance(nodes, list) else [nodes], edges if isinstance(edges, list) else [edges], ) - elif nodes is not None: + if nodes is not None: return self._medrecord.add_group( group, nodes if isinstance(nodes, list) else [nodes], None ) - elif edges is not None: + if edges is not None: return self._medrecord.add_group( group, None, edges if isinstance(edges, list) else [edges] ) - else: - return self._medrecord.add_group(group, None, None) + return self._medrecord.add_group(group, None, None) def remove_group(self, group: Union[Group, GroupInputList]) -> None: """Removes one or more groups from the MedRecord instance. diff --git a/medmodels/medrecord/querying.py b/medmodels/medrecord/querying.py index 5a6e8385..7f749578 100644 --- a/medmodels/medrecord/querying.py +++ b/medmodels/medrecord/querying.py @@ -1360,12 +1360,9 @@ def has_neighbor_with( return NodeOperation( self._node_operand.has_neighbor_with(operation._node_operation) ) - else: - return NodeOperation( - self._node_operand.has_neighbor_undirected_with( - operation._node_operation - ) - ) + return NodeOperation( + self._node_operand.has_neighbor_undirected_with(operation._node_operation) + ) def attribute(self, attribute: MedRecordAttribute) -> NodeAttributeOperand: """Accesses an NodeAttributeOperand for the specified attribute, allowing for the creation of operations based on node attributes. diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index 2dd28890..f7887a73 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -20,8 +20,7 @@ class AttributeType(Enum): @staticmethod def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: - """ - Converts a PyAttributeType to an AttributeType. + """Converts a PyAttributeType to an AttributeType. Args: py_attribute_type (PyAttributeType): The PyAttributeType to convert. @@ -33,28 +32,25 @@ def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: return AttributeType.Categorical elif py_attribute_type == PyAttributeType.Continuous: return AttributeType.Continuous - elif py_attribute_type == PyAttributeType.Temporal: + if py_attribute_type == PyAttributeType.Temporal: return AttributeType.Temporal def _into_pyattributetype(self) -> PyAttributeType: - """ - Converts an AttributeType to a PyAttributeType. + """Converts an AttributeType to a PyAttributeType. Returns: PyAttributeType: The converted PyAttributeType. """ if self == AttributeType.Categorical: return PyAttributeType.Categorical - elif self == AttributeType.Continuous: + if self == AttributeType.Continuous: return PyAttributeType.Continuous - elif self == AttributeType.Temporal: + if self == AttributeType.Temporal: return PyAttributeType.Temporal - else: - raise NotImplementedError("Should never be reached") + raise NotImplementedError("Should never be reached") def __repr__(self) -> str: - """ - Returns a string representation of the AttributeType instance. + """Returns a string representation of the AttributeType instance. Returns: str: String representation of the attribute type. @@ -62,8 +58,7 @@ def __repr__(self) -> str: return f"AttributeType.{self.name}" def __str__(self) -> str: - """ - Returns a string representation of the AttributeType instance. + """Returns a string representation of the AttributeType instance. Returns: str: String representation of the attribute type. @@ -71,8 +66,7 @@ def __str__(self) -> str: return self.name def __eq__(self, value: object) -> bool: - """ - Compares the AttributeType instance to another object for equality. + """Compares the AttributeType instance to another object for equality. Args: value (object): The object to compare against. @@ -82,7 +76,7 @@ def __eq__(self, value: object) -> bool: """ if isinstance(value, PyAttributeType): return self._into_pyattributetype() == value - elif isinstance(value, AttributeType): + if isinstance(value, AttributeType): return str(self) == str(value) return False @@ -99,8 +93,7 @@ def __init__( MedRecordAttribute, Tuple[DataType, Optional[AttributeType]] ], ) -> None: - """ - Initializes a new instance of AttributesSchema. + """Initializes a new instance of AttributesSchema. Args: attributes_schema (Dict[MedRecordAttribute, Tuple[DataType, Optional[AttributeType]]]): @@ -113,8 +106,7 @@ def __init__( self._attributes_schema = attributes_schema def __repr__(self) -> str: - """ - Returns a string representation of the AttributesSchema instance. + """Returns a string representation of the AttributesSchema instance. Returns: str: String representation of the attribute schema. @@ -124,8 +116,7 @@ def __repr__(self) -> str: def __getitem__( self, key: MedRecordAttribute ) -> Tuple[DataType, Optional[AttributeType]]: - """ - Gets the data type and optional attribute type for a given MedRecordAttribute. + """Gets the data type and optional attribute type for a given MedRecordAttribute. Args: key (MedRecordAttribute): The attribute for which the data type is @@ -138,8 +129,7 @@ def __getitem__( return self._attributes_schema[key] def __contains__(self, key: MedRecordAttribute) -> bool: - """ - Checks if a given MedRecordAttribute is in the attributes schema. + """Checks if a given MedRecordAttribute is in the attributes schema. Args: key (MedRecordAttribute): The attribute to check. @@ -150,8 +140,7 @@ def __contains__(self, key: MedRecordAttribute) -> bool: return key in self._attributes_schema def __iter__(self): - """ - Returns an iterator over the attributes schema. + """Returns an iterator over the attributes schema. Returns: Iterator: An iterator over the attribute keys. @@ -159,8 +148,7 @@ def __iter__(self): return self._attributes_schema.__iter__() def __len__(self) -> int: - """ - Returns the number of attributes in the schema. + """Returns the number of attributes in the schema. Returns: int: The number of attributes. @@ -168,8 +156,7 @@ def __len__(self) -> int: return len(self._attributes_schema) def __eq__(self, value: object) -> bool: - """ - Compares the AttributesSchema instance to another object for equality. + """Compares the AttributesSchema instance to another object for equality. Args: value (object): The object to compare against. @@ -200,8 +187,7 @@ def __eq__(self, value: object) -> bool: return True def keys(self): - """ - Returns the attribute keys in the schema. + """Returns the attribute keys in the schema. Returns: KeysView: A view object displaying a list of dictionary's keys. @@ -209,8 +195,7 @@ def keys(self): return self._attributes_schema.keys() def values(self): - """ - Returns the attribute values in the schema. + """Returns the attribute values in the schema. Returns: ValuesView: A view object displaying a list of dictionary's values. @@ -218,8 +203,7 @@ def values(self): return self._attributes_schema.values() def items(self): - """ - Returns the attribute key-value pairs in the schema. + """Returns the attribute key-value pairs in the schema. Returns: ItemsView: A set-like object providing a view on D's items. @@ -241,8 +225,7 @@ def get( key: MedRecordAttribute, default: Optional[Tuple[DataType, Optional[AttributeType]]] = None, ) -> Optional[Tuple[DataType, Optional[AttributeType]]]: - """ - Gets the data type and optional attribute type for a given attribute, returning + """Gets the data type and optional attribute type for a given attribute, returning a default value if the attribute is not present. Args: @@ -273,8 +256,7 @@ def __init__( ] = {}, strict: bool = False, ) -> None: - """ - Initializes a new instance of GroupSchema. + """Initializes a new instance of GroupSchema. Args: nodes (Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]): @@ -307,8 +289,7 @@ def _convert_input( @classmethod def _from_pygroupschema(cls, group_schema: PyGroupSchema) -> GroupSchema: - """ - Creates a GroupSchema instance from an existing PyGroupSchema. + """Creates a GroupSchema instance from an existing PyGroupSchema. Args: group_schema (PyGroupSchema): The PyGroupSchema instance to convert. @@ -322,8 +303,7 @@ def _from_pygroupschema(cls, group_schema: PyGroupSchema) -> GroupSchema: @property def nodes(self) -> AttributesSchema: - """ - Returns the node attributes in the GroupSchema instance. + """Returns the node attributes in the GroupSchema instance. Returns: AttributesSchema: An AttributesSchema object containing the node attributes @@ -349,8 +329,7 @@ def _convert_node( @property def edges(self) -> AttributesSchema: - """ - Returns the edge attributes in the GroupSchema instance. + """Returns the edge attributes in the GroupSchema instance. Returns: AttributesSchema: An AttributesSchema object containing the edge attributes @@ -376,8 +355,7 @@ def _convert_edge( @property def strict(self) -> Optional[bool]: - """ - Indicates whether the GroupSchema instance is strict. + """Indicates whether the GroupSchema instance is strict. Returns: Optional[bool]: True if the schema is strict, False otherwise. @@ -395,8 +373,7 @@ def __init__( default: Optional[GroupSchema] = None, strict: bool = False, ) -> None: - """ - Initializes a new instance of Schema. + """Initializes a new instance of Schema. Args: groups (Dict[Group, GroupSchema], optional): A dictionary of group names @@ -423,8 +400,7 @@ def __init__( @classmethod def _from_pyschema(cls, schema: PySchema) -> Schema: - """ - Creates a Schema instance from an existing PySchema. + """Creates a Schema instance from an existing PySchema. Args: schema (PySchema): The PySchema instance to convert. @@ -438,8 +414,7 @@ def _from_pyschema(cls, schema: PySchema) -> Schema: @property def groups(self) -> List[Group]: - """ - Lists all the groups in the Schema instance. + """Lists all the groups in the Schema instance. Returns: List[Group]: A list of groups. @@ -447,8 +422,7 @@ def groups(self) -> List[Group]: return self._schema.groups def group(self, group: Group) -> GroupSchema: - """ - Retrieves the schema for a specific group. + """Retrieves the schema for a specific group. Args: group (Group): The name of the group. @@ -463,8 +437,7 @@ def group(self, group: Group) -> GroupSchema: @property def default(self) -> Optional[GroupSchema]: - """ - Retrieves the default group schema. + """Retrieves the default group schema. Returns: Optional[GroupSchema]: The default group schema if it exists, otherwise @@ -477,8 +450,7 @@ def default(self) -> Optional[GroupSchema]: @property def strict(self) -> Optional[bool]: - """ - Indicates whether the Schema instance is strict. + """Indicates whether the Schema instance is strict. Returns: Optional[bool]: True if the schema is strict, False otherwise. diff --git a/medmodels/treatment_effect/matching/matching.py b/medmodels/treatment_effect/matching/matching.py index 5f276ea5..a396970f 100644 --- a/medmodels/treatment_effect/matching/matching.py +++ b/medmodels/treatment_effect/matching/matching.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Literal, Set, Tuple import polars as pl @@ -19,7 +19,7 @@ MatchingMethod: TypeAlias = Literal["propensity", "nearest_neighbors"] -class Matching(metaclass=ABCMeta): +class Matching(ABC): """The Base Class for matching.""" def _preprocess_data( diff --git a/pyproject.toml b/pyproject.toml index fde95ac8..2f8dcafb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,8 +88,58 @@ exclude = [ line-length = 88 [tool.ruff.lint] -select = ["PLR"] +select = [ + "E", # PEP 8 codestyle errors + "F", # pyflakes + "I", # isort + "N", # PEP 8 naming + "DOC", # Pydoc Linting (preview); complementary to "D" + "D", # Pydoc Style; PEP 257 + "FA", # future annotations linting; PEP 563 + "W", # pycodestyle warnings; PEP 8 + "SIM", # flake8 simplify; simplify code + "ANN", # flake8 function annotations; PEP 3107 + "B", # bugbear extension for flake8; opinionated, not based on any PEP + "C4", # list/set/dict comprehensions + "T10", # Check for debugging leftovers: pdb;idbp imports and set traces + "EM", # error messages + "LOG", # logging module usage linting + "G", # logging format strings + "T20", # print statements + "PYI", # lint stub files .pyi + "PT", # pytest linting + "RET", # return values + "TCH", # type checking + "PTH", # pathlib usage + "PERF", # performance linting + "FURB", # modern python code patterns + "RUF", # ruff specific rules + "FBT", # no bool as function param + "TD", # todo linting + "C90", # mccabe complexity +] preview = true +ignore = [ + "E501", # Line length managed by formatter + # indentation linters conflicting with formatter: + "W191", + "E111", + "E114", + "E117", + "D206", + # quotation linters conflicting with formatter: + "D300", + "Q000", + "Q001", + "Q002", + "Q003", + # comma linters conflicting with formatter: + "COM812", + "COM819", + # string concatenation linters conflicting with formatter: + "ISC001", + "ISC002", +] [tool.ruff.lint.pydocstyle] convention = "google" From bfddc992d28947138f089bf237a23d1ee4b07b54 Mon Sep 17 00:00:00 2001 From: FloLimebit Date: Tue, 27 Aug 2024 16:29:21 +0200 Subject: [PATCH 3/7] #161 ruff unsafe fixes applied --- docs/conf.py | 2 +- docs/developer_guide/docstrings.md | 59 +- docs/developer_guide/example_docstrings.py | 292 +--- medmodels/_medmodels.pyi | 1 - medmodels/medrecord/builder.py | 7 +- medmodels/medrecord/datatype.py | 3 +- medmodels/medrecord/indexers.py | 274 +-- medmodels/medrecord/querying.py | 2 +- medmodels/medrecord/schema.py | 32 +- medmodels/medrecord/tests/test_builder.py | 46 +- medmodels/medrecord/tests/test_datatype.py | 129 +- medmodels/medrecord/tests/test_indexers.py | 1552 ++++------------- medmodels/medrecord/tests/test_medrecord.py | 691 ++++---- medmodels/medrecord/tests/test_querying.py | 1086 ++---------- medmodels/medrecord/tests/test_schema.py | 183 +- medmodels/treatment_effect/builder.py | 36 +- .../treatment_effect/continuous_estimators.py | 12 +- medmodels/treatment_effect/estimate.py | 44 +- .../algorithms/classic_distance_models.py | 3 +- .../matching/algorithms/propensity_score.py | 6 +- .../treatment_effect/matching/matching.py | 6 +- .../treatment_effect/matching/neighbors.py | 18 +- .../treatment_effect/matching/propensity.py | 18 +- .../tests/test_classic_distance_models.py | 23 +- .../matching/tests/test_metrics.py | 22 +- .../matching/tests/test_propensity_score.py | 18 +- medmodels/treatment_effect/report.py | 5 +- .../treatment_effect/temporal_analysis.py | 9 +- .../tests/test_continuous_estimators.py | 46 +- .../tests/test_temporal_analysis.py | 55 +- .../tests/test_treatment_effect.py | 249 +-- .../treatment_effect/treatment_effect.py | 52 +- 32 files changed, 1524 insertions(+), 3457 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 331789f6..dac1fb9d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -172,7 +172,7 @@ # Local Sphinx extensions -def setup(app: Sphinx): +def setup(app: Sphinx) -> None: """Add custom directives and transformations to Sphinx.""" from myst_parser._docs import ( DirectiveDoc, diff --git a/docs/developer_guide/docstrings.md b/docs/developer_guide/docstrings.md index 5668959c..36215c1b 100644 --- a/docs/developer_guide/docstrings.md +++ b/docs/developer_guide/docstrings.md @@ -29,21 +29,14 @@ The summary line provides a brief description of what the function or class does Example: -```python -def example_function(param1, param2): - """This function performs an example operation. - - Args: - param1 (int): The first parameter. - param2 (str): The second parameter. - - Returns: - bool: True if successful, False otherwise. - """ +```{literalinclude} example_docstrings.py +:lines: 1-5, 6-38 +:emphasize-lines: 13 ``` -:::{note} -The summary line cannot contain line breaks. When writing docstrings, the max line length setting can be exceeded without triggering linting errors. +:::{admonition} No Linebreaks In Summary Line +:type: note +The summary line cannot contain line breaks and must not exceed the maximum line length. Ensure that the summary line provides a brief description of the function's purpose. Further explanation and details can be given in the [Description](#description) text block underneath. ::: ### Description @@ -221,26 +214,28 @@ Document parameters under the `Args` section. Each parameter should include its Example: -```python -def example_function(param1, param2, param3): - """This function performs an example operation. - - Args: - param1 (int): The first parameter. - param2 (Union[str, List[str], Dict[str, Any]]): The second parameter, which can be - a string, a list of strings, or a dictionary with string keys and any type of - values. This parameter is used to demonstrate a long type definition. - param3 (str): The third parameter, which has a long description that needs to be - broken into multiple lines for better readability. This parameter is used - to show how to write long descriptions and how to indent them properly in - the docstring. - """ +```{literalinclude} example_docstrings.py +:lines: 1-5, 6-38 +:emphasize-lines: 18-26 ``` -:::{note} -Type definitions of parameters cannot have a line break. You can not put a line break inside `(Union[str, List[str], Dict[str, Any]])`. When writing docstrings, the max line length setting can be exceeded without triggering linting errors. +:::{admonition} No Linebreak in Type Definitions +:type: note + +When writing type definitions in argument docstrings, avoid placing line breaks inside the type annotations. For instance, complex types like `(Union[str, List[str], Dict[str, Any]])` should appear on a single line without splitting across lines. This ensures the type definition remains clear and avoids parsing issues. + +If a type definition exceeds the maximum line length limit, deactivate the line length rule for that doctstring using [`# noqa: E501`](https://docs.astral.sh/ruff/rules/line-too-long/#error-suppression) after the closing `"""`. This keeps the type annotation intact while preventing the linter from flagging the long line as an error. ::: + +:::{admonition} No Boolean Argument Types +:type: note +Avoid using booleans as function arguments due to the "boolean trap." Booleans reduce code clarity, as `True` or `False` doesn't explain meaning. They also limit future flexibility — if more than two options are needed, changing the argument type can cause breaking changes. Instead, use an enum or a more descriptive type for better clarity and flexibility. + +For more information, you can refer to [Adam Johnson's article](https://adamj.eu/tech/2021/07/10/python-type-hints-how-to-avoid-the-boolean-trap/) which discusses this in detail and provides examples of how to avoid the boolean trap. +::: + + ### Return Types Document return types under the `Returns` section. Each return type should include the type and a brief description. @@ -256,6 +251,12 @@ def example_function(param1, param2): """ ``` +:::{admonition} Don't Document `None` Return Value +:type: note +Don't add a `Returns` section when the function has no return value (`def fun() -> None`) +::: + + ### Examples Provide examples under the `Examples` section by using Sphinx's `.. code-block::` directive within the docstrings. diff --git a/docs/developer_guide/example_docstrings.py b/docs/developer_guide/example_docstrings.py index b95a21a2..7d1306a3 100644 --- a/docs/developer_guide/example_docstrings.py +++ b/docs/developer_guide/example_docstrings.py @@ -1,113 +1,44 @@ -"""Example Google style docstrings. +from __future__ import annotations -This module demonstrates documentation as specified by the `Google Python -Style Guide`_. Docstrings may extend over multiple lines. Sections are created -with a section header and a colon followed by a block of indented text. +from typing import Any, Dict, Iterator -Example: - Examples can be given using either the ``Example`` or ``Examples`` - sections. Sections support any reStructuredText formatting, including - literal blocks:: - $ python example_google.py +def example_function_args( + param1: int, + param2: str | int, + optional_param: list[str] | None = None, + *args: float | str, + **kwargs: Dict[str, Any] +) -> tuple[bool, list[str]]: + """Example function with PEP 484 type annotations and PEP 563 future annotations. -Section breaks are created by resuming unindented text. Section breaks -are also implicitly created anytime a new section starts. - -Attributes: - module_level_variable1 (int): Module level variables may be documented in - either the ``Attributes`` section of the module docstring, or in an - inline docstring immediately following the variable. - - Either form is acceptable, but the two should not be mixed. Choose - one convention to document module level variables and be consistent - with it. - -Todo: - * For module TODOs - * You have to also use ``sphinx.ext.todo`` extension - -.. _Google Python Style Guide: - https://google.github.io/styleguide/pyguide.html - -""" - -module_level_variable1 = 12345 - -module_level_variable2 = 98765 -"""int: Module level variable documented inline. - -The docstring may span multiple lines. The type may optionally be specified -on the first line, separated by a colon. -""" - - -def function(param1: int, param2: int) -> bool: - """Example function with PEP 484 type annotations. + This function shows how to define and document typing for different kinds of + arguments, including positional, optional, variable-length args, and keyword args. Args: - param1: The first parameter. - param2: The second parameter. + param1 (int): A required integer parameter. + param2 (str | int): A parameter that can be either a string or an integer. + optional_param (list[str] | None, optional): An optional parameter that accepts + a list of strings. Defaults to None if not provided. + *args (float | str): Variable length argument list that accepts floats or + strings. + **kwargs (Dict[str, Any]): Arbitrary keyword arguments as a dictionary of string + keys and values of any type. Returns: - The return value. True for success, False otherwise. - + tuple[bool, list[str]]: A tuple containing: + - bool: Always True for this example. + - list[str]: A list with a single string describing the received arguments. """ - return param1 == param2 + result = ( + f"Received: param1={param1}, param2={param2}, optional_param={optional_param}, " + f"args={args}, kwargs={kwargs}" + ) + return True, [result] -def module_level_function(param1: int, param2: int, *args, **kwargs): - """This is an example of a module level function. - - Function parameters should be documented in the ``Args`` section. The name - of each parameter is required. The type and description of each parameter - is optional, but should be included if not obvious. - - If ``*args`` or ``**kwargs`` are accepted, - they should be listed as ``*args`` and ``**kwargs``. - - The format for a parameter is:: - - name (type): description - The description may span multiple lines. Following - lines should be indented. The "(type)" is optional. - - Multiple paragraphs are supported in parameter - descriptions. - - Args: - param1 (int): The first parameter. - param2 (:obj:`str`, optional): The second parameter. Defaults to None. - Second line of description should be indented. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - bool: True if successful, False otherwise. - - The return type is optional and may be specified at the beginning of - the ``Returns`` section followed by a colon. - The ``Returns`` section may span multiple lines and paragraphs. - Following lines should be indented to match the first line. - - The ``Returns`` section supports any reStructuredText formatting, - including literal blocks:: - - {"param1": param1, "param2": param2} - - Raises: - AttributeError: The ``Raises`` section is a list of all exceptions - that are relevant to the interface. - ValueError: If `param2` is equal to `param1`. - - """ - if param1 == param2: - raise ValueError("param1 may not be equal to param2") - return True - - -def example_generator(n): +def example_generator(n: int) -> Iterator[int]: """Generators have a ``Yields`` section instead of a ``Returns`` section. Args: @@ -125,168 +56,3 @@ def example_generator(n): """ yield from range(n) - - -class ExampleError(Exception): - """Exceptions are documented in the same way as classes. - - The __init__ method may be documented in either the class level - docstring, or as a docstring on the __init__ method itself. - - Either form is acceptable, but the two should not be mixed. Choose one - convention to document the __init__ method and be consistent with it. - - Note: - Do not include the `self` parameter in the ``Args`` section. - - Args: - msg (str): Human readable string describing the exception. - code (:obj:`int`, optional): Error code. - - Attributes: - msg (str): Human readable string describing the exception. - code (int): Exception error code. - - """ - - def __init__(self, msg, code): - self.msg = msg - self.code = code - - -class ExampleClass: - """The summary line for a class docstring should fit on one line. - - If the class has public attributes, they may be documented here - in an ``Attributes`` section and follow the same formatting as a - function's ``Args`` section. Alternatively, attributes may be documented - inline with the attribute's declaration (see __init__ method below). - - Properties created with the ``@property`` decorator should be documented - in the property's getter method. - - Attributes: - attr1 (str): Description of `attr1`. - attr2 (:obj:`int`, optional): Description of `attr2`. - - """ - - def __init__(self, param1, param2, param3): - """Example of docstring on the __init__ method. - - The __init__ method may be documented in either the class level - docstring, or as a docstring on the __init__ method itself. - - Either form is acceptable, but the two should not be mixed. Choose one - convention to document the __init__ method and be consistent with it. - - Note: - Do not include the `self` parameter in the ``Args`` section. - - Args: - param1 (str): Description of `param1`. - param2 (:obj:`int`, optional): Description of `param2`. Multiple - lines are supported. - param3 (list(str)): Description of `param3`. - - """ - self.attr1 = param1 - self.attr2 = param2 - self.attr3 = param3 #: Doc comment *inline* with attribute - - #: list(str): Doc comment *before* attribute, with type specified - self.attr4 = ["attr4"] - - self.attr5 = None - """str: Docstring *after* attribute, with type specified.""" - - @property - def readonly_property(self): - """str: Properties should be documented in their getter method.""" - return "readonly_property" - - @property - def readwrite_property(self): - """list(str): Properties with both a getter and setter - should only be documented in their getter method. - - If the setter method contains notable behavior, it should be - mentioned here. - """ - return ["readwrite_property"] - - @readwrite_property.setter - def readwrite_property(self, value): - value - - def example_method(self, param1, param2): - """Class methods are similar to regular functions. - - Note: - Do not include the `self` parameter in the ``Args`` section. - - Args: - param1: The first parameter. - param2: The second parameter. - - Returns: - True if successful, False otherwise. - - """ - return True - - def __special__(self): - """By default special members with docstrings are not included. - - Special members are any methods or attributes that start with and - end with a double underscore. Any special member with a docstring - will be included in the output, if - ``napoleon_include_special_with_doc`` is set to True. - - This behavior can be enabled by changing the following setting in - Sphinx's conf.py:: - - napoleon_include_special_with_doc = True - - """ - pass - - def __special_without_docstring__(self): - pass - - def _private(self): - """By default private members are not included. - - Private members are any methods or attributes that start with an - underscore and are *not* special. By default they are not included - in the output. - - This behavior can be changed such that private members *are* included - by changing the following setting in Sphinx's conf.py:: - - napoleon_include_private_with_doc = True - - """ - pass - - def _private_without_docstring(self): - pass - - -class ExamplePEP526Class: - """The summary line for a class docstring should fit on one line. - - If the class has public attributes, they may be documented here - in an ``Attributes`` section and follow the same formatting as a - function's ``Args`` section. If ``napoleon_attr_annotations`` - is True, types can be specified in the class body using ``PEP 526`` - annotations. - - Attributes: - attr1: Description of `attr1`. - attr2: Description of `attr2`. - - """ - - attr1: str - attr2: int diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index 275ce805..ba4125f4 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -1,4 +1,3 @@ -from __future__ import annotations from enum import Enum from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union diff --git a/medmodels/medrecord/builder.py b/medmodels/medrecord/builder.py index 90c6a9b9..b25bae20 100644 --- a/medmodels/medrecord/builder.py +++ b/medmodels/medrecord/builder.py @@ -5,8 +5,9 @@ if TYPE_CHECKING: from typing_extensions import TypeIs + from medmodels.medrecord.schema import Schema + import medmodels as mm -from medmodels.medrecord.schema import Schema from medmodels.medrecord.types import ( EdgeTuple, Group, @@ -128,7 +129,7 @@ def add_edges( return self def add_group( - self, group: Group, *, nodes: List[NodeIndex] = [] + self, group: Group, *, nodes: Optional[List[NodeIndex]] = None ) -> MedRecordBuilder: """Adds a group to the builder with an optional list of nodes. @@ -139,6 +140,8 @@ def add_group( Returns: MedRecordBuilder: The current instance of the builder. """ + if nodes is None: + nodes = [] self._groups[group] = {"nodes": nodes, "edges": []} return self diff --git a/medmodels/medrecord/datatype.py b/medmodels/medrecord/datatype.py index 28241bd2..daa8f3b2 100644 --- a/medmodels/medrecord/datatype.py +++ b/medmodels/medrecord/datatype.py @@ -211,7 +211,8 @@ class Union(DataType): def __init__(self, *dtypes: DataType) -> None: if len(dtypes) < 2: - raise ValueError("Union must have at least two arguments") + msg = "Union must have at least two arguments" + raise ValueError(msg) if len(dtypes) == 2: self._union = PyUnion(dtypes[0]._inner(), dtypes[1]._inner()) else: diff --git a/medmodels/medrecord/indexers.py b/medmodels/medrecord/indexers.py index 5fbe808c..439e1335 100644 --- a/medmodels/medrecord/indexers.py +++ b/medmodels/medrecord/indexers.py @@ -92,7 +92,8 @@ def __getitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node(self._medrecord.nodes) @@ -110,7 +111,7 @@ def __getitem__( ): attributes = self._medrecord._medrecord.node(index_selection) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, NodeOperation) and is_medrecord_attribute( attribute_selection @@ -119,7 +120,7 @@ def __getitem__( self._medrecord.select_nodes(index_selection) ) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, slice) and is_medrecord_attribute( attribute_selection @@ -129,11 +130,12 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(self._medrecord.nodes) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if is_node_index(index_selection) and isinstance(attribute_selection, list): return { @@ -148,7 +150,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, NodeOperation) and isinstance( @@ -160,7 +162,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, slice) and isinstance(attribute_selection, list): @@ -169,13 +171,14 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(self._medrecord.nodes) return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if is_node_index(index_selection) and isinstance(attribute_selection, slice): @@ -184,7 +187,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node([index_selection])[index_selection] @@ -194,7 +198,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node(index_selection) @@ -206,7 +211,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node( self._medrecord.select_nodes(index_selection) @@ -223,9 +229,11 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.node(self._medrecord.nodes) + return None @overload def __setitem__( @@ -260,19 +268,22 @@ def __setitem__( ) -> None: if is_node_index(key): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes([key], value) if isinstance(key, list): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes(key, value) if isinstance(key, NodeOperation): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.select_nodes(key), value @@ -280,10 +291,12 @@ def __setitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.nodes, value @@ -295,7 +308,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( [index_selection], attribute_selection, value @@ -305,7 +319,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( index_selection, attribute_selection, value @@ -315,7 +330,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( self._medrecord.select_nodes(index_selection), @@ -331,10 +347,12 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_node_attribute( self._medrecord.nodes, @@ -344,7 +362,8 @@ def __setitem__( if is_node_index(index_selection) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( @@ -355,7 +374,8 @@ def __setitem__( if isinstance(index_selection, list) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( @@ -368,7 +388,8 @@ def __setitem__( attribute_selection, list ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( @@ -383,10 +404,12 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_node_attribute( @@ -401,16 +424,18 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node([index_selection])[ index_selection ] - for attribute in attributes.keys(): + for attribute in attributes: self._medrecord._medrecord.update_node_attribute( [index_selection], attribute, @@ -425,15 +450,17 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(index_selection) - for node in attributes.keys(): - for attribute in attributes[node].keys(): + for node in attributes: + for attribute in attributes[node]: self._medrecord._medrecord.update_node_attribute( [node], attribute, value ) @@ -448,17 +475,19 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node( self._medrecord.select_nodes(index_selection) ) - for node in attributes.keys(): - for attribute in attributes[node].keys(): + for node in attributes: + for attribute in attributes[node]: self._medrecord._medrecord.update_node_attribute( [node], attribute, value ) @@ -476,20 +505,23 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.node(self._medrecord.nodes) - for node in attributes.keys(): - for attribute in attributes[node].keys(): + for node in attributes: + for attribute in attributes[node]: self._medrecord._medrecord.update_node_attribute( [node], attribute, value ) return None + return None def __delitem__( self, @@ -530,7 +562,8 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.remove_node_attribute( self._medrecord.nodes, @@ -569,7 +602,8 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.remove_node_attribute( @@ -584,7 +618,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( [index_selection], {} @@ -596,7 +631,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( index_selection, {} @@ -610,7 +646,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.select_nodes(index_selection), {} @@ -627,11 +664,13 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_node_attributes( self._medrecord.nodes, {} ) + return None class EdgeIndexer: @@ -702,7 +741,8 @@ def __getitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge(self._medrecord.edges) @@ -720,7 +760,7 @@ def __getitem__( ): attributes = self._medrecord._medrecord.edge(index_selection) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, EdgeOperation) and is_medrecord_attribute( attribute_selection @@ -729,7 +769,7 @@ def __getitem__( self._medrecord.select_edges(index_selection) ) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if isinstance(index_selection, slice) and is_medrecord_attribute( attribute_selection @@ -739,11 +779,12 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(self._medrecord.edges) - return {x: attributes[x][attribute_selection] for x in attributes.keys()} + return {x: attributes[x][attribute_selection] for x in attributes} if is_edge_index(index_selection) and isinstance(attribute_selection, list): return { @@ -758,7 +799,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, EdgeOperation) and isinstance( @@ -770,7 +811,7 @@ def __getitem__( return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if isinstance(index_selection, slice) and isinstance(attribute_selection, list): @@ -779,13 +820,14 @@ def __getitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(self._medrecord.edges) return { x: {y: attributes[x][y] for y in attribute_selection} - for x in attributes.keys() + for x in attributes } if is_edge_index(index_selection) and isinstance(attribute_selection, slice): @@ -794,7 +836,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge([index_selection])[index_selection] @@ -804,7 +847,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge(index_selection) @@ -816,7 +860,8 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge( self._medrecord.select_edges(index_selection) @@ -833,9 +878,11 @@ def __getitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.edge(self._medrecord.edges) + return None @overload def __setitem__( @@ -870,19 +917,22 @@ def __setitem__( ) -> None: if is_edge_index(key): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes([key], value) if isinstance(key, list): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes(key, value) if isinstance(key, EdgeOperation): if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.select_edges(key), value @@ -890,10 +940,12 @@ def __setitem__( if isinstance(key, slice): if key.start is not None or key.stop is not None or key.step is not None: - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_attributes(value): - raise ValueError("Invalid value type. Expected Attributes") + msg = "Invalid value type. Expected Attributes" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.edges, value @@ -905,7 +957,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( [index_selection], attribute_selection, value @@ -915,7 +968,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( index_selection, attribute_selection, value @@ -925,7 +979,8 @@ def __setitem__( attribute_selection ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( self._medrecord.select_edges(index_selection), @@ -941,10 +996,12 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) return self._medrecord._medrecord.update_edge_attribute( self._medrecord.edges, @@ -954,7 +1011,8 @@ def __setitem__( if is_edge_index(index_selection) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( @@ -965,7 +1023,8 @@ def __setitem__( if isinstance(index_selection, list) and isinstance(attribute_selection, list): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( @@ -978,7 +1037,8 @@ def __setitem__( attribute_selection, list ): if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( @@ -993,10 +1053,12 @@ def __setitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.update_edge_attribute( @@ -1011,16 +1073,18 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge([index_selection])[ index_selection ] - for attribute in attributes.keys(): + for attribute in attributes: self._medrecord._medrecord.update_edge_attribute( [index_selection], attribute, value ) @@ -1033,15 +1097,17 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(index_selection) - for edge in attributes.keys(): - for attribute in attributes[edge].keys(): + for edge in attributes: + for attribute in attributes[edge]: self._medrecord._medrecord.update_edge_attribute( [edge], attribute, value ) @@ -1056,17 +1122,19 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge( self._medrecord.select_edges(index_selection) ) - for edge in attributes.keys(): - for attribute in attributes[edge].keys(): + for edge in attributes: + for attribute in attributes[edge]: self._medrecord._medrecord.update_edge_attribute( [edge], attribute, value ) @@ -1084,20 +1152,23 @@ def __setitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) if not is_medrecord_value(value): - raise ValueError("Invalid value type. Expected MedRecordValue") + msg = "Invalid value type. Expected MedRecordValue" + raise ValueError(msg) attributes = self._medrecord._medrecord.edge(self._medrecord.edges) - for edge in attributes.keys(): - for attribute in attributes[edge].keys(): + for edge in attributes: + for attribute in attributes[edge]: self._medrecord._medrecord.update_edge_attribute( [edge], attribute, value ) return None + return None def __delitem__( self, @@ -1138,7 +1209,8 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.remove_edge_attribute( self._medrecord.edges, @@ -1177,7 +1249,8 @@ def __delitem__( or index_selection.stop is not None or index_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) for attribute in attribute_selection: self._medrecord._medrecord.remove_edge_attribute( @@ -1192,7 +1265,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( [index_selection], {} @@ -1204,7 +1278,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( index_selection, {} @@ -1218,7 +1293,8 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.select_edges(index_selection), {} @@ -1235,8 +1311,10 @@ def __delitem__( or attribute_selection.stop is not None or attribute_selection.step is not None ): - raise ValueError("Invalid slice, only ':' is allowed") + msg = "Invalid slice, only ':' is allowed" + raise ValueError(msg) return self._medrecord._medrecord.replace_edge_attributes( self._medrecord.edges, {} ) + return None diff --git a/medmodels/medrecord/querying.py b/medmodels/medrecord/querying.py index 7f749578..4a18d80e 100644 --- a/medmodels/medrecord/querying.py +++ b/medmodels/medrecord/querying.py @@ -33,7 +33,7 @@ class NodeOperation: _node_operation: PyNodeOperation - def __init__(self, node_operation: PyNodeOperation): + def __init__(self, node_operation: PyNodeOperation) -> None: self._node_operation = node_operation def logical_and(self, operation: NodeOperation) -> NodeOperation: diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index f7887a73..bcc4e85d 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum, auto -from typing import Dict, List, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, overload from medmodels._medmodels import ( PyAttributeDataType, @@ -10,7 +10,9 @@ PySchema, ) from medmodels.medrecord.datatype import DataType -from medmodels.medrecord.types import Group, MedRecordAttribute + +if TYPE_CHECKING: + from medmodels.medrecord.types import Group, MedRecordAttribute class AttributeType(Enum): @@ -30,10 +32,11 @@ def _from_pyattributetype(py_attribute_type: PyAttributeType) -> AttributeType: """ if py_attribute_type == PyAttributeType.Categorical: return AttributeType.Categorical - elif py_attribute_type == PyAttributeType.Continuous: + if py_attribute_type == PyAttributeType.Continuous: return AttributeType.Continuous if py_attribute_type == PyAttributeType.Temporal: return AttributeType.Temporal + return None def _into_pyattributetype(self) -> PyAttributeType: """Converts an AttributeType to a PyAttributeType. @@ -47,7 +50,8 @@ def _into_pyattributetype(self) -> PyAttributeType: return PyAttributeType.Continuous if self == AttributeType.Temporal: return PyAttributeType.Temporal - raise NotImplementedError("Should never be reached") + msg = "Should never be reached" + raise NotImplementedError(msg) def __repr__(self) -> str: """Returns a string representation of the AttributeType instance. @@ -164,7 +168,7 @@ def __eq__(self, value: object) -> bool: Returns: bool: True if the objects are equal, False otherwise. """ - if not (isinstance(value, AttributesSchema) or isinstance(value, dict)): + if not (isinstance(value, (AttributesSchema, dict))): return False attribute_schema = ( @@ -174,7 +178,7 @@ def __eq__(self, value: object) -> bool: if not attribute_schema.keys() == self._attributes_schema.keys(): return False - for key in self._attributes_schema.keys(): + for key in self._attributes_schema: if ( not isinstance(attribute_schema[key], tuple) or not isinstance( @@ -248,12 +252,8 @@ class GroupSchema: def __init__( self, *, - nodes: Dict[ - MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]] - ] = {}, - edges: Dict[ - MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]] - ] = {}, + nodes: Optional[Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]] = None, + edges: Optional[Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]] = None, strict: bool = False, ) -> None: """Initializes a new instance of GroupSchema. @@ -271,6 +271,10 @@ def __init__( Returns: None """ + if edges is None: + edges = {} + if nodes is None: + nodes = {} def _convert_input( input: Union[DataType, Tuple[DataType, AttributeType]], @@ -369,7 +373,7 @@ class Schema: def __init__( self, *, - groups: Dict[Group, GroupSchema] = {}, + groups: Optional[Dict[Group, GroupSchema]] = None, default: Optional[GroupSchema] = None, strict: bool = False, ) -> None: @@ -386,6 +390,8 @@ def __init__( Returns: None """ + if groups is None: + groups = {} if default is not None: self._schema = PySchema( groups={x: groups[x]._group_schema for x in groups}, diff --git a/medmodels/medrecord/tests/test_builder.py b/medmodels/medrecord/tests/test_builder.py index 73d01e80..3d5a5a67 100644 --- a/medmodels/medrecord/tests/test_builder.py +++ b/medmodels/medrecord/tests/test_builder.py @@ -1,64 +1,66 @@ import unittest +import pytest + import medmodels.medrecord as mr class TestMedRecordBuilder(unittest.TestCase): - def test_add_nodes(self): + def test_add_nodes(self) -> None: builder = mr.MedRecord.builder().add_nodes([("node1", {})]) - self.assertEqual(len(builder._nodes), 1) + assert len(builder._nodes) == 1 builder.add_nodes([("node2", {})], group="group") - self.assertEqual(len(builder._nodes), 2) + assert len(builder._nodes) == 2 medrecord = builder.build() - self.assertEqual(2, len(medrecord.nodes)) - self.assertEqual(1, len(medrecord.groups)) - self.assertEqual(["group"], medrecord.groups_of_node("node2")) + assert len(medrecord.nodes) == 2 + assert len(medrecord.groups) == 1 + assert medrecord.groups_of_node("node2") == ["group"] - def test_add_edges(self): + def test_add_edges(self) -> None: builder = ( mr.MedRecord.builder() .add_nodes([("node1", {}), ("node2", {})]) .add_edges([("node1", "node2", {})]) ) - self.assertEqual(len(builder._edges), 1) + assert len(builder._edges) == 1 builder.add_edges([("node2", "node1", {})], group="group") medrecord = builder.build() - self.assertEqual(2, len(medrecord.nodes)) - self.assertEqual(2, len(medrecord.edges)) - self.assertEqual(1, len(medrecord.groups)) - self.assertEqual(["node2"], medrecord.neighbors("node1")) - self.assertEqual(["group"], medrecord.groups_of_edge(1)) + assert len(medrecord.nodes) == 2 + assert len(medrecord.edges) == 2 + assert len(medrecord.groups) == 1 + assert medrecord.neighbors("node1") == ["node2"] + assert medrecord.groups_of_edge(1) == ["group"] - def test_add_group(self): + def test_add_group(self) -> None: builder = ( mr.MedRecord.builder().add_nodes(("0", {})).add_group("group", nodes=["0"]) ) - self.assertEqual(len(builder._groups), 1) + assert len(builder._groups) == 1 medrecord = builder.build() - self.assertEqual(1, len(medrecord.nodes)) - self.assertEqual(0, len(medrecord.edges)) - self.assertEqual(1, len(medrecord.groups)) - self.assertEqual("group", medrecord.groups[0]) - self.assertEqual(["0"], medrecord.nodes_in_group("group")) + assert len(medrecord.nodes) == 1 + assert len(medrecord.edges) == 0 + assert len(medrecord.groups) == 1 + assert medrecord.groups[0] == "group" + assert medrecord.nodes_in_group("group") == ["0"] - def test_with_schema(self): + def test_with_schema(self) -> None: schema = mr.Schema(default=mr.GroupSchema(nodes={"attribute": mr.Int()})) medrecord = mr.MedRecord.builder().with_schema(schema).build() medrecord.add_node("node", {"attribute": 1}) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.add_node("node", {"attribute": "1"}) diff --git a/medmodels/medrecord/tests/test_datatype.py b/medmodels/medrecord/tests/test_datatype.py index 0cf16cd5..331519bb 100644 --- a/medmodels/medrecord/tests/test_datatype.py +++ b/medmodels/medrecord/tests/test_datatype.py @@ -1,5 +1,7 @@ import unittest +import pytest + import medmodels.medrecord as mr from medmodels._medmodels import ( PyAny, @@ -15,121 +17,112 @@ class TestDataType(unittest.TestCase): - def test_string(self): + def test_string(self) -> None: string = mr.String() - self.assertTrue(isinstance(string._inner(), PyString)) + assert isinstance(string._inner(), PyString) - self.assertEqual("String", str(string)) + assert str(string) == "String" - self.assertEqual("DataType.String", string.__repr__()) + assert string.__repr__() == "DataType.String" - self.assertEqual(mr.String(), mr.String()) - self.assertNotEqual(mr.String(), mr.Int()) + assert mr.String() == mr.String() + assert mr.String() != mr.Int() - def test_int(self): + def test_int(self) -> None: integer = mr.Int() - self.assertTrue(isinstance(integer._inner(), PyInt)) + assert isinstance(integer._inner(), PyInt) - self.assertEqual("Int", str(integer)) + assert str(integer) == "Int" - self.assertEqual("DataType.Int", integer.__repr__()) + assert integer.__repr__() == "DataType.Int" - self.assertEqual(mr.Int(), mr.Int()) - self.assertNotEqual(mr.Int(), mr.String()) + assert mr.Int() == mr.Int() + assert mr.Int() != mr.String() - def test_float(self): + def test_float(self) -> None: float = mr.Float() - self.assertTrue(isinstance(float._inner(), PyFloat)) + assert isinstance(float._inner(), PyFloat) - self.assertEqual("Float", str(float)) + assert str(float) == "Float" - self.assertEqual("DataType.Float", float.__repr__()) + assert float.__repr__() == "DataType.Float" - self.assertEqual(mr.Float(), mr.Float()) - self.assertNotEqual(mr.Float(), mr.String()) + assert mr.Float() == mr.Float() + assert mr.Float() != mr.String() - def test_bool(self): + def test_bool(self) -> None: bool = mr.Bool() - self.assertTrue(isinstance(bool._inner(), PyBool)) + assert isinstance(bool._inner(), PyBool) - self.assertEqual("Bool", str(bool)) + assert str(bool) == "Bool" - self.assertEqual("DataType.Bool", bool.__repr__()) + assert bool.__repr__() == "DataType.Bool" - self.assertEqual(mr.Bool(), mr.Bool()) - self.assertNotEqual(mr.Bool(), mr.String()) + assert mr.Bool() == mr.Bool() + assert mr.Bool() != mr.String() - def test_datetime(self): + def test_datetime(self) -> None: datetime = mr.DateTime() - self.assertTrue(isinstance(datetime._inner(), PyDateTime)) + assert isinstance(datetime._inner(), PyDateTime) - self.assertEqual("DateTime", str(datetime)) + assert str(datetime) == "DateTime" - self.assertEqual("DataType.DateTime", datetime.__repr__()) + assert datetime.__repr__() == "DataType.DateTime" - self.assertEqual(mr.DateTime(), mr.DateTime()) - self.assertNotEqual(mr.DateTime(), mr.String()) + assert mr.DateTime() == mr.DateTime() + assert mr.DateTime() != mr.String() - def test_null(self): + def test_null(self) -> None: null = mr.Null() - self.assertTrue(isinstance(null._inner(), PyNull)) + assert isinstance(null._inner(), PyNull) - self.assertEqual("Null", str(null)) + assert str(null) == "Null" - self.assertEqual("DataType.Null", null.__repr__()) + assert null.__repr__() == "DataType.Null" - self.assertEqual(mr.Null(), mr.Null()) - self.assertNotEqual(mr.Null(), mr.String()) + assert mr.Null() == mr.Null() + assert mr.Null() != mr.String() - def test_any(self): + def test_any(self) -> None: any = mr.Any() - self.assertTrue(isinstance(any._inner(), PyAny)) + assert isinstance(any._inner(), PyAny) - self.assertEqual("Any", str(any)) + assert str(any) == "Any" - self.assertEqual("DataType.Any", any.__repr__()) + assert any.__repr__() == "DataType.Any" - self.assertEqual(mr.Any(), mr.Any()) - self.assertNotEqual(mr.Any(), mr.String()) + assert mr.Any() == mr.Any() + assert mr.Any() != mr.String() - def test_union(self): + def test_union(self) -> None: union = mr.Union(mr.String(), mr.Int()) - self.assertTrue(isinstance(union._inner(), PyUnion)) + assert isinstance(union._inner(), PyUnion) - self.assertEqual("Union(String, Int)", str(union)) + assert str(union) == "Union(String, Int)" - self.assertEqual( - "DataType.Union(DataType.String, DataType.Int)", union.__repr__() - ) + assert union.__repr__() == "DataType.Union(DataType.String, DataType.Int)" union = mr.Union(mr.String(), mr.Int(), mr.Bool()) - self.assertTrue(isinstance(union._inner(), PyUnion)) + assert isinstance(union._inner(), PyUnion) - self.assertEqual("Union(String, Union(Int, Bool))", str(union)) + assert str(union) == "Union(String, Union(Int, Bool))" - self.assertEqual( - "DataType.Union(DataType.String, DataType.Union(DataType.Int, DataType.Bool))", - union.__repr__(), - ) + assert union.__repr__() == "DataType.Union(DataType.String, DataType.Union(DataType.Int, DataType.Bool))" - self.assertEqual( - mr.Union(mr.String(), mr.Int()), mr.Union(mr.String(), mr.Int()) - ) - self.assertNotEqual( - mr.Union(mr.String(), mr.Int()), mr.Union(mr.Int(), mr.String()) - ) + assert mr.Union(mr.String(), mr.Int()) == mr.Union(mr.String(), mr.Int()) + assert mr.Union(mr.String(), mr.Int()) != mr.Union(mr.Int(), mr.String()) - def test_invalid_union(self): - with self.assertRaises(ValueError): + def test_invalid_union(self) -> None: + with pytest.raises(ValueError): mr.Union(mr.String()) - def test_option(self): + def test_option(self) -> None: option = mr.Option(mr.String()) - self.assertTrue(isinstance(option._inner(), PyOption)) + assert isinstance(option._inner(), PyOption) - self.assertEqual("Option(String)", str(option)) + assert str(option) == "Option(String)" - self.assertEqual("DataType.Option(DataType.String)", option.__repr__()) + assert option.__repr__() == "DataType.Option(DataType.String)" - self.assertEqual(mr.Option(mr.String()), mr.Option(mr.String())) - self.assertNotEqual(mr.Option(mr.String()), mr.Option(mr.Int())) + assert mr.Option(mr.String()) == mr.Option(mr.String()) + assert mr.Option(mr.String()) != mr.Option(mr.Int()) diff --git a/medmodels/medrecord/tests/test_indexers.py b/medmodels/medrecord/tests/test_indexers.py index 17492c6d..b8879e82 100644 --- a/medmodels/medrecord/tests/test_indexers.py +++ b/medmodels/medrecord/tests/test_indexers.py @@ -1,5 +1,7 @@ import unittest +import pytest + from medmodels import MedRecord from medmodels.medrecord import edge, node @@ -22,354 +24,207 @@ def create_medrecord(): class TestMedRecord(unittest.TestCase): - def test_node_getitem(self): + def test_node_getitem(self) -> None: medrecord = create_medrecord() - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.node[0] - ) + assert medrecord.node[0] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} # Accessing a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.node[50] - self.assertEqual("bar", medrecord.node[0, "foo"]) + assert medrecord.node[0, "foo"] == "bar" # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[0, "test"] - self.assertEqual( - {"foo": "bar", "bar": "foo"}, medrecord.node[0, ["foo", "bar"]] - ) + assert medrecord.node[0, ["foo", "bar"]] == {"foo": "bar", "bar": "foo"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[0, ["foo", "test"]] - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.node[0, :] - ) + assert medrecord.node[0, :] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.node[[0, 1]], - ) + assert medrecord.node[[0, 1]] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.node[[0, 50]] - self.assertEqual( - { - 0: "bar", - 1: "bar", - }, - medrecord.node[[0, 1], "foo"], - ) + assert medrecord.node[[0, 1], "foo"] == {0: "bar", 1: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], "test"] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], "lorem"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.node[[0, 1], ["foo", "bar"]], - ) + assert medrecord.node[[0, 1], ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[[0, 1], ["foo", "lorem"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.node[[0, 1], :], - ) + assert medrecord.node[[0, 1], :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], ::1] - self.assertEqual( - {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.node[node().index() >= 2], - ) + assert medrecord.node[node().index() >= 2] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} # Empty query should not fail - self.assertEqual( - {}, - medrecord.node[node().index() > 3], - ) + assert medrecord.node[node().index() > 3] == {} - self.assertEqual( - {2: "bar", 3: "bar"}, - medrecord.node[node().index() >= 2, "foo"], - ) + assert medrecord.node[node().index() >= 2, "foo"] == {2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[node().index() >= 2, "test"] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[node().index() >= 2, ["foo", "bar"]], - ) + assert medrecord.node[node().index() >= 2, ["foo", "bar"]] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[node().index() >= 2, ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[node().index() < 2, ["foo", "lorem"]] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[node().index() >= 2, :], - ) + assert medrecord.node[node().index() >= 2, :] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1] - self.assertEqual( - { - 0: "bar", - 1: "bar", - 2: "bar", - 3: "bar", - }, - medrecord.node[:, "foo"], - ) + assert medrecord.node[:, "foo"] == {0: "bar", 1: "bar", 2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:, ["foo", "bar"]], - ) + assert medrecord.node[:, ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[:, ["foo", "test"]] # Accessing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.node[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:, :], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:, :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, ::1] - def test_node_setitem(self): + def test_node_setitem(self) -> None: # Updating existing attributes medrecord = create_medrecord() medrecord.node[0] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Updating a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.node[50] = {"foo": "bar", "test": "test"} medrecord = create_medrecord() medrecord.node[0, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[0, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[0, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[0, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[0, ::1] = "test" medrecord = create_medrecord() medrecord.node[[0, 1], "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[[0, 1], ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[[0, 1], :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[[0, 1], 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[[0, 1], ::1] = "test" medrecord = create_medrecord() medrecord.node[node().index() >= 2] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "test"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "test"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Empty query should not fail @@ -377,929 +232,490 @@ def test_node_setitem(self): medrecord = create_medrecord() medrecord.node[node().index() >= 2, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} medrecord = create_medrecord() medrecord.node[node().index() >= 2, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} medrecord = create_medrecord() medrecord.node[node().index() >= 2, :] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[node().index() >= 2, ::1] = "test" medrecord = create_medrecord() medrecord.node[:, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[1:, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, "foo"] = "test" medrecord = create_medrecord() medrecord.node[:, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[1:, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, ["foo", "bar"]] = "test" medrecord = create_medrecord() medrecord.node[:, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.node[1:, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[::1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.node[:, ::1] = "test" # Adding new attributes medrecord = create_medrecord() medrecord.node[0, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[0, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[[0, 1], "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[[0, 1], ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[node().index() >= 2, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} medrecord = create_medrecord() medrecord.node[node().index() >= 2, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: { - "foo": "bar", - "bar": "foo", - "test": "test", - "test2": "test", - }, - 3: { - "foo": "bar", - "bar": "test", - "test": "test", - "test2": "test", - }, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} medrecord = create_medrecord() medrecord.node[:, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} medrecord = create_medrecord() medrecord.node[:, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} # Adding and updating attributes medrecord = create_medrecord() medrecord.node[[0, 1], "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[[0, 1], ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[node().index() < 2, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[node().index() < 2, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.node[:, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test"}} medrecord = create_medrecord() medrecord.node[:, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, - }, - medrecord.node[:], - ) - - def test_node_delitem(self): + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}} + + def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[0, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing from a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.node[50, "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[0, "test"] medrecord = create_medrecord() del medrecord.node[0, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[0, ["foo", "test"]] medrecord = create_medrecord() del medrecord.node[0, :] - self.assertEqual( - { - 0: {}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): del medrecord.node[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[0, ::1] medrecord = create_medrecord() del medrecord.node[[0, 1], "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing from a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.node[[0, 50], "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[[0, 1], "test"] medrecord = create_medrecord() del medrecord.node[[0, 1], ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[[0, 1], ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[[0, 1], ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.node[[0, 1], :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): del medrecord.node[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[[0, 1], ::1] medrecord = create_medrecord() del medrecord.node[node().index() >= 2, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} medrecord = create_medrecord() # Empty query should not fail del medrecord.node[node().index() > 3, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[node().index() >= 2, "test"] medrecord = create_medrecord() del medrecord.node[node().index() >= 2, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[node().index() >= 2, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[node().index() < 2, ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.node[node().index() >= 2, :] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} + + with pytest.raises(ValueError): del medrecord.node[node().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[node().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[node().index() >= 2, ::1] medrecord = create_medrecord() del medrecord.node[:, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[::1, "foo"] medrecord = create_medrecord() del medrecord.node[:, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) + assert medrecord.node[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {}, 3: {}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[:, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all nodes should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.node[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[::1, ["foo", "bar"]] medrecord = create_medrecord() del medrecord.node[:, :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.node[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.node[:] == {0: {}, 1: {}, 2: {}, 3: {}} + + with pytest.raises(ValueError): del medrecord.node[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.node[:, ::1] - def test_edge_getitem(self): + def test_edge_getitem(self) -> None: medrecord = create_medrecord() - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.edge[0] - ) + assert medrecord.edge[0] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} # Accessing a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge[50] - self.assertEqual("bar", medrecord.edge[0, "foo"]) + assert medrecord.edge[0, "foo"] == "bar" # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[0, "test"] - self.assertEqual( - {"foo": "bar", "bar": "foo"}, medrecord.edge[0, ["foo", "bar"]] - ) + assert medrecord.edge[0, ["foo", "bar"]] == {"foo": "bar", "bar": "foo"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[0, ["foo", "test"]] - self.assertEqual( - {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, medrecord.edge[0, :] - ) + assert medrecord.edge[0, :] == {"foo": "bar", "bar": "foo", "lorem": "ipsum"} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.edge[[0, 1]], - ) + assert medrecord.edge[[0, 1]] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge[[0, 50]] - self.assertEqual( - { - 0: "bar", - 1: "bar", - }, - medrecord.edge[[0, 1], "foo"], - ) + assert medrecord.edge[[0, 1], "foo"] == {0: "bar", 1: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], "test"] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], "lorem"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.edge[[0, 1], ["foo", "bar"]], - ) + assert medrecord.edge[[0, 1], ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[[0, 1], ["foo", "lorem"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - }, - medrecord.edge[[0, 1], :], - ) + assert medrecord.edge[[0, 1], :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], ::1] - self.assertEqual( - {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}}, - medrecord.edge[edge().index() >= 2], - ) + assert medrecord.edge[edge().index() >= 2] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} # Empty query should not fail - self.assertEqual( - {}, - medrecord.edge[edge().index() > 3], - ) + assert medrecord.edge[edge().index() > 3] == {} - self.assertEqual( - {2: "bar", 3: "bar"}, - medrecord.edge[edge().index() >= 2, "foo"], - ) + assert medrecord.edge[edge().index() >= 2, "foo"] == {2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[edge().index() >= 2, "test"] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[edge().index() >= 2, ["foo", "bar"]], - ) + assert medrecord.edge[edge().index() >= 2, ["foo", "bar"]] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[edge().index() >= 2, ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[edge().index() < 2, ["foo", "lorem"]] - self.assertEqual( - { - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[edge().index() >= 2, :], - ) + assert medrecord.edge[edge().index() >= 2, :] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, ::1] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1] - self.assertEqual( - { - 0: "bar", - 1: "bar", - 2: "bar", - 3: "bar", - }, - medrecord.edge[:, "foo"], - ) + assert medrecord.edge[:, "foo"] == {0: "bar", 1: "bar", 2: "bar", 3: "bar"} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:, ["foo", "bar"]], - ) + assert medrecord.edge[:, ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} # Accessing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[:, ["foo", "test"]] # Accessing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): medrecord.edge[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:, :], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:, :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, ::1] - def test_edge_setitem(self): + def test_edge_setitem(self) -> None: # Updating existing attributes medrecord = create_medrecord() medrecord.edge[0] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Updating a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge[50] = {"foo": "bar", "test": "test"} medrecord = create_medrecord() medrecord.edge[0, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[0, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[0, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[0, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[0, ::1] = "test" medrecord = create_medrecord() medrecord.edge[[0, 1], "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[[0, 1], ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[[0, 1], :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[[0, 1], 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[[0, 1], ::1] = "test" medrecord = create_medrecord() medrecord.edge[edge().index() >= 2] = {"foo": "bar", "bar": "test"} - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "test"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "test"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Empty query should not fail @@ -1307,577 +723,285 @@ def test_edge_setitem(self): medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, :] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, ::1] = "test" medrecord = create_medrecord() medrecord.edge[:, "foo"] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "foo"}, - 2: {"foo": "test", "bar": "foo"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[1:, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, "foo"] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, "foo"] = "test" medrecord = create_medrecord() medrecord.edge[:, ["foo", "bar"]] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[1:, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, ["foo", "bar"]] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, ["foo", "bar"]] = "test" medrecord = create_medrecord() medrecord.edge[:, :] = "test" - self.assertEqual( - { - 0: {"foo": "test", "bar": "test", "lorem": "test"}, - 1: {"foo": "test", "bar": "test"}, - 2: {"foo": "test", "bar": "test"}, - 3: {"foo": "test", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + + with pytest.raises(ValueError): medrecord.edge[1:, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[::1, :] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, 1:] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, :1] = "test" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.edge[:, ::1] = "test" # Adding new attributes medrecord = create_medrecord() medrecord.edge[0, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[0, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[[0, 1], "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[[0, 1], ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: { - "foo": "bar", - "bar": "foo", - "test": "test", - "test2": "test", - }, - 3: { - "foo": "bar", - "bar": "test", - "test": "test", - "test2": "test", - }, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} medrecord = create_medrecord() medrecord.edge[:, "test"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} medrecord = create_medrecord() medrecord.edge[:, ["test", "test2"]] = "test" - self.assertEqual( - { - 0: { - "foo": "bar", - "bar": "foo", - "lorem": "ipsum", - "test": "test", - "test2": "test", - }, - 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, - 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} # Adding and updating attributes medrecord = create_medrecord() medrecord.edge[[0, 1], "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[[0, 1], ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[edge().index() < 2, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[edge().index() < 2, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() medrecord.edge[:, "lorem"] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test"}} medrecord = create_medrecord() medrecord.edge[:, ["lorem", "test"]] = "test" - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, - 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, - }, - medrecord.edge[:], - ) - - def test_edge_delitem(self): + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}} + + def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[0, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing from a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.edge[50, "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[0, "test"] medrecord = create_medrecord() del medrecord.edge[0, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[0, ["foo", "test"]] medrecord = create_medrecord() del medrecord.edge[0, :] - self.assertEqual( - { - 0: {}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): del medrecord.edge[0, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[0, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[0, ::1] medrecord = create_medrecord() del medrecord.edge[[0, 1], "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing from a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): del medrecord.edge[[0, 50], "foo"] medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[[0, 1], "test"] medrecord = create_medrecord() del medrecord.edge[[0, 1], ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[[0, 1], ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[[0, 1], ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.edge[[0, 1], :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + + with pytest.raises(ValueError): del medrecord.edge[[0, 1], 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[[0, 1], :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[[0, 1], ::1] medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} medrecord = create_medrecord() # Empty query should not fail del medrecord.edge[edge().index() > 3, "foo"] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {"foo": "bar", "bar": "foo"}, - 3: {"foo": "bar", "bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[edge().index() >= 2, "test"] medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, ["foo", "bar"]] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[edge().index() >= 2, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[edge().index() < 2, ["foo", "lorem"]] medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, :] - self.assertEqual( - { - 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, - 1: {"foo": "bar", "bar": "foo"}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} + + with pytest.raises(ValueError): del medrecord.edge[edge().index() >= 2, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[edge().index() >= 2, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[edge().index() >= 2, ::1] medrecord = create_medrecord() del medrecord.edge[:, "foo"] - self.assertEqual( - { - 0: {"bar": "foo", "lorem": "ipsum"}, - 1: {"bar": "foo"}, - 2: {"bar": "foo"}, - 3: {"bar": "test"}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[:, "test"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[1:, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:1, "foo"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[::1, "foo"] medrecord = create_medrecord() del medrecord.edge[:, ["foo", "bar"]] - self.assertEqual( - { - 0: {"lorem": "ipsum"}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) + assert medrecord.edge[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {}, 3: {}} medrecord = create_medrecord() # Removing a non-existing key should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[:, ["foo", "test"]] medrecord = create_medrecord() # Removing a key that doesn't exist in all edges should fail - with self.assertRaises(KeyError): + with pytest.raises(KeyError): del medrecord.edge[:, ["foo", "lorem"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[1:, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:1, ["foo", "bar"]] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[::1, ["foo", "bar"]] medrecord = create_medrecord() del medrecord.edge[:, :] - self.assertEqual( - { - 0: {}, - 1: {}, - 2: {}, - 3: {}, - }, - medrecord.edge[:], - ) - - with self.assertRaises(ValueError): + assert medrecord.edge[:] == {0: {}, 1: {}, 2: {}, 3: {}} + + with pytest.raises(ValueError): del medrecord.edge[1:, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[::1, :] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:, 1:] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:, :1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): del medrecord.edge[:, ::1] diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index 555bbd04..a0619852 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -4,6 +4,7 @@ import pandas as pd import polars as pl +import pytest import medmodels.medrecord as mr from medmodels import MedRecord @@ -73,30 +74,30 @@ def create_medrecord() -> MedRecord: class TestMedRecord(unittest.TestCase): - def test_from_tuples(self): + def test_from_tuples(self) -> None: medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 - def test_invalid_from_tuples(self): + def test_invalid_from_tuples(self) -> None: nodes = create_nodes() # Adding an edge pointing to a non-existent node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): MedRecord.from_tuples(nodes, [("0", "50", {})]) # Adding an edge from a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): MedRecord.from_tuples(nodes, [("50", "0", {})]) - def test_from_pandas(self): + def test_from_pandas(self) -> None: medrecord = MedRecord.from_pandas( (create_pandas_nodes_dataframe(), "index"), ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 0 medrecord = MedRecord.from_pandas( [ @@ -105,16 +106,16 @@ def test_from_pandas(self): ], ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 0 medrecord = MedRecord.from_pandas( (create_pandas_nodes_dataframe(), "index"), (create_pandas_edges_dataframe(), "source", "target"), ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_pandas( [ @@ -124,8 +125,8 @@ def test_from_pandas(self): (create_pandas_edges_dataframe(), "source", "target"), ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_pandas( (create_pandas_nodes_dataframe(), "index"), @@ -135,8 +136,8 @@ def test_from_pandas(self): ], ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 4 medrecord = MedRecord.from_pandas( [ @@ -149,10 +150,10 @@ def test_from_pandas(self): ], ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 - def test_from_polars(self): + def test_from_polars(self) -> None: nodes = pl.from_pandas(create_pandas_nodes_dataframe()) second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) edges = pl.from_pandas(create_pandas_edges_dataframe()) @@ -160,83 +161,83 @@ def test_from_polars(self): medrecord = MedRecord.from_polars((nodes, "index"), (edges, "source", "target")) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_polars( [(nodes, "index"), (second_nodes, "index")], (edges, "source", "target") ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 2 medrecord = MedRecord.from_polars( (nodes, "index"), [(edges, "source", "target"), (second_edges, "source", "target")], ) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 2 + assert medrecord.edge_count() == 4 medrecord = MedRecord.from_polars( [(nodes, "index"), (second_nodes, "index")], [(edges, "source", "target"), (second_edges, "source", "target")], ) - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 - def test_invalid_from_polars(self): + def test_invalid_from_polars(self) -> None: nodes = pl.from_pandas(create_pandas_nodes_dataframe()) second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) edges = pl.from_pandas(create_pandas_edges_dataframe()) second_edges = pl.from_pandas(create_second_pandas_edges_dataframe()) # Providing the wrong node index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars((nodes, "invalid"), (edges, "source", "target")) # Providing the wrong node index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars( [(nodes, "index"), (second_nodes, "invalid")], (edges, "source", "target"), ) # Providing the wrong source index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars((nodes, "index"), (edges, "invalid", "target")) # Providing the wrong source index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars( (nodes, "index"), [(edges, "source", "target"), (second_edges, "invalid", "target")], ) # Providing the wrong target index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars((nodes, "index"), (edges, "source", "invalid")) # Providing the wrong target index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): MedRecord.from_polars( (nodes, "index"), [(edges, "source", "target"), (edges, "source", "invalid")], ) - def test_from_example_dataset(self): + def test_from_example_dataset(self) -> None: medrecord = MedRecord.from_example_dataset() - self.assertEqual(73, medrecord.node_count()) - self.assertEqual(160, medrecord.edge_count()) + assert medrecord.node_count() == 73 + assert medrecord.edge_count() == 160 - self.assertEqual(25, len(medrecord.nodes_in_group("diagnosis"))) - self.assertEqual(19, len(medrecord.nodes_in_group("drug"))) - self.assertEqual(5, len(medrecord.nodes_in_group("patient"))) - self.assertEqual(24, len(medrecord.nodes_in_group("procedure"))) + assert len(medrecord.nodes_in_group("diagnosis")) == 25 + assert len(medrecord.nodes_in_group("drug")) == 19 + assert len(medrecord.nodes_in_group("patient")) == 5 + assert len(medrecord.nodes_in_group("procedure")) == 24 - def test_ron(self): + def test_ron(self) -> None: medrecord = create_medrecord() with tempfile.NamedTemporaryFile() as f: @@ -244,10 +245,10 @@ def test_ron(self): loaded_medrecord = MedRecord.from_ron(f.name) - self.assertEqual(medrecord.node_count(), loaded_medrecord.node_count()) - self.assertEqual(medrecord.edge_count(), loaded_medrecord.edge_count()) + assert medrecord.node_count() == loaded_medrecord.node_count() + assert medrecord.edge_count() == loaded_medrecord.edge_count() - def test_schema(self): + def test_schema(self) -> None: schema = mr.Schema( groups={ "group": mr.GroupSchema( @@ -264,7 +265,7 @@ def test_schema(self): medrecord.add_node("0", {"attribute": 1}) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.add_node("1", {"attribute": "1"}) medrecord.add_node("1", {"attribute": 1, "attribute2": 1}) @@ -273,12 +274,12 @@ def test_schema(self): medrecord.add_node("2", {"attribute": 1, "attribute2": "1"}) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.add_node_to_group("group", "2") medrecord.add_edge("0", "1", {"attribute": 1}) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.add_edge("0", "1", {"attribute": "1"}) edge_index = medrecord.add_edge("0", "1", {"attribute": 1, "attribute2": 1}) @@ -287,273 +288,257 @@ def test_schema(self): edge_index = medrecord.add_edge("0", "1", {"attribute": 1, "attribute2": "1"}) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): medrecord.add_edge_to_group("group", edge_index) - def test_nodes(self): + def test_nodes(self) -> None: medrecord = create_medrecord() nodes = [x[0] for x in create_nodes()] for node in medrecord.nodes: - self.assertTrue(node in nodes) + assert node in nodes - def test_edges(self): + def test_edges(self) -> None: medrecord = create_medrecord() edges = list(range(len(create_edges()))) for edge in medrecord.edges: - self.assertTrue(edge in edges) + assert edge in edges - def test_groups(self): + def test_groups(self) -> None: medrecord = create_medrecord() medrecord.add_group("0") - self.assertEqual(["0"], medrecord.groups) + assert medrecord.groups == ["0"] - def test_group(self): + def test_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0") - self.assertEqual({"nodes": [], "edges": []}, medrecord.group("0")) + assert medrecord.group("0") == {"nodes": [], "edges": []} medrecord.add_group("1", ["0"], [0]) - self.assertEqual({"nodes": ["0"], "edges": [0]}, medrecord.group("1")) + assert medrecord.group("1") == {"nodes": ["0"], "edges": [0]} - self.assertEqual( - {"0": {"nodes": [], "edges": []}, "1": {"nodes": ["0"], "edges": [0]}}, - medrecord.group(["0", "1"]), - ) + assert medrecord.group(["0", "1"]) == {"0": {"nodes": [], "edges": []}, "1": {"nodes": ["0"], "edges": [0]}} - def test_invalid_group(self): + def test_invalid_group(self) -> None: medrecord = create_medrecord() # Querying a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.group("0") medrecord.add_group("1", ["0"]) # Querying a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.group(["0", "50"]) - def test_outgoing_edges(self): + def test_outgoing_edges(self) -> None: medrecord = create_medrecord() edges = medrecord.outgoing_edges("0") - self.assertEqual( - sorted([0, 3]), - sorted(edges), - ) + assert sorted([0, 3]) == sorted(edges) edges = medrecord.outgoing_edges(["0", "1"]) - self.assertEqual( - {"0": sorted([0, 3]), "1": [1, 2]}, - {key: sorted(value) for (key, value) in edges.items()}, - ) + assert {key: sorted(value) for key, value in edges.items()} == {"0": sorted([0, 3]), "1": [1, 2]} edges = medrecord.outgoing_edges(node_select().index().is_in(["0", "1"])) - self.assertEqual( - {"0": sorted([0, 3]), "1": [1, 2]}, - {key: sorted(value) for (key, value) in edges.items()}, - ) + assert {key: sorted(value) for key, value in edges.items()} == {"0": sorted([0, 3]), "1": [1, 2]} - def test_invalid_outgoing_edges(self): + def test_invalid_outgoing_edges(self) -> None: medrecord = create_medrecord() # Querying outgoing edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.outgoing_edges("50") # Querying outgoing edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.outgoing_edges(["0", "50"]) - def test_incoming_edges(self): + def test_incoming_edges(self) -> None: medrecord = create_medrecord() edges = medrecord.incoming_edges("1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.incoming_edges(["1", "2"]) - self.assertEqual({"1": [0], "2": [2]}, edges) + assert edges == {"1": [0], "2": [2]} edges = medrecord.incoming_edges(node_select().index().is_in(["1", "2"])) - self.assertEqual({"1": [0], "2": [2]}, edges) + assert edges == {"1": [0], "2": [2]} - def test_invalid_incoming_edges(self): + def test_invalid_incoming_edges(self) -> None: medrecord = create_medrecord() # Querying incoming edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.incoming_edges("50") # Querying incoming edges of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.incoming_edges(["0", "50"]) - def test_edge_endpoints(self): + def test_edge_endpoints(self) -> None: medrecord = create_medrecord() endpoints = medrecord.edge_endpoints(0) - self.assertEqual(("0", "1"), endpoints) + assert endpoints == ("0", "1") endpoints = medrecord.edge_endpoints([0, 1]) - self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) + assert endpoints == {0: ("0", "1"), 1: ("1", "0")} endpoints = medrecord.edge_endpoints(edge_select().index().is_in([0, 1])) - self.assertEqual({0: ("0", "1"), 1: ("1", "0")}, endpoints) + assert endpoints == {0: ("0", "1"), 1: ("1", "0")} - def test_invalid_edge_endpoints(self): + def test_invalid_edge_endpoints(self) -> None: medrecord = create_medrecord() # Querying endpoints of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge_endpoints(50) # Querying endpoints of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edge_endpoints([0, 50]) - def test_edges_connecting(self): + def test_edges_connecting(self) -> None: medrecord = create_medrecord() edges = medrecord.edges_connecting("0", "1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.edges_connecting(["0", "1"], "1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.edges_connecting(node_select().index().is_in(["0", "1"]), "1") - self.assertEqual([0], edges) + assert edges == [0] edges = medrecord.edges_connecting("0", ["1", "3"]) - self.assertEqual(sorted([0, 3]), sorted(edges)) + assert sorted([0, 3]) == sorted(edges) edges = medrecord.edges_connecting("0", node_select().index().is_in(["1", "3"])) - self.assertEqual(sorted([0, 3]), sorted(edges)) + assert sorted([0, 3]) == sorted(edges) edges = medrecord.edges_connecting(["0", "1"], ["1", "2", "3"]) - self.assertEqual(sorted([0, 2, 3]), sorted(edges)) + assert sorted([0, 2, 3]) == sorted(edges) edges = medrecord.edges_connecting( node_select().index().is_in(["0", "1"]), node_select().index().is_in(["1", "2", "3"]), ) - self.assertEqual(sorted([0, 2, 3]), sorted(edges)) + assert sorted([0, 2, 3]) == sorted(edges) edges = medrecord.edges_connecting("0", "1", directed=False) - self.assertEqual([0, 1], sorted(edges)) + assert sorted(edges) == [0, 1] - def test_add_node(self): + def test_add_node(self) -> None: medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_node("0", {}) - self.assertEqual(1, medrecord.node_count()) + assert medrecord.node_count() == 1 - def test_invalid_add_node(self): + def test_invalid_add_node(self) -> None: medrecord = create_medrecord() - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_node("0", {}) - def test_remove_node(self): + def test_remove_node(self) -> None: medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 attributes = medrecord.remove_node("0") - self.assertEqual(3, medrecord.node_count()) - self.assertEqual(create_nodes()[0][1], attributes) + assert medrecord.node_count() == 3 + assert create_nodes()[0][1] == attributes attributes = medrecord.remove_node(["1", "2"]) - self.assertEqual(1, medrecord.node_count()) - self.assertEqual( - {"1": create_nodes()[1][1], "2": create_nodes()[2][1]}, attributes - ) + assert medrecord.node_count() == 1 + assert attributes == {"1": create_nodes()[1][1], "2": create_nodes()[2][1]} medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 attributes = medrecord.remove_node(node_select().index().is_in(["0", "1"])) - self.assertEqual(2, medrecord.node_count()) - self.assertEqual( - {"0": create_nodes()[0][1], "1": create_nodes()[1][1]}, attributes - ) + assert medrecord.node_count() == 2 + assert attributes == {"0": create_nodes()[0][1], "1": create_nodes()[1][1]} - def test_invalid_remove_node(self): + def test_invalid_remove_node(self) -> None: medrecord = create_medrecord() # Removing a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_node("50") # Removing a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_node(["0", "50"]) - def test_add_nodes(self): + def test_add_nodes(self) -> None: medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes(create_nodes()) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 # Adding pandas dataframe medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes((create_pandas_nodes_dataframe(), "index")) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 # Adding polars dataframe medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 nodes = pl.from_pandas(create_pandas_nodes_dataframe()) medrecord.add_nodes((nodes, "index")) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 # Adding multiple pandas dataframes medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes( [ @@ -562,14 +547,14 @@ def test_add_nodes(self): ] ) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 # Adding multiple polars dataframes medrecord = MedRecord() second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes( [ @@ -578,80 +563,80 @@ def test_add_nodes(self): ] ) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 - def test_invalid_add_nodes(self): + def test_invalid_add_nodes(self) -> None: medrecord = create_medrecord() - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_nodes(create_nodes()) - def test_add_nodes_pandas(self): + def test_add_nodes_pandas(self) -> None: medrecord = MedRecord() nodes = (create_pandas_nodes_dataframe(), "index") - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes(nodes) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 medrecord = MedRecord() second_nodes = (create_second_pandas_nodes_dataframe(), "index") - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes([nodes, second_nodes]) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 - def test_add_nodes_polars(self): + def test_add_nodes_polars(self) -> None: medrecord = MedRecord() nodes = pl.from_pandas(create_pandas_nodes_dataframe()) - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes_polars((nodes, "index")) - self.assertEqual(2, medrecord.node_count()) + assert medrecord.node_count() == 2 medrecord = MedRecord() second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_nodes_polars([(nodes, "index"), (second_nodes, "index")]) - self.assertEqual(4, medrecord.node_count()) + assert medrecord.node_count() == 4 - def test_invalid_add_nodes_polars(self): + def test_invalid_add_nodes_polars(self) -> None: medrecord = MedRecord() nodes = pl.from_pandas(create_pandas_nodes_dataframe()) second_nodes = pl.from_pandas(create_second_pandas_nodes_dataframe()) # Adding a nodes dataframe with the wrong index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_nodes_polars((nodes, "invalid")) # Adding a nodes dataframe with the wrong index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_nodes_polars([(nodes, "index"), (second_nodes, "invalid")]) - def test_add_edge(self): + def test_add_edge(self) -> None: medrecord = create_medrecord() - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 medrecord.add_edge("0", "3", {}) - self.assertEqual(5, medrecord.edge_count()) + assert medrecord.edge_count() == 5 - def test_invalid_add_edge(self): + def test_invalid_add_edge(self) -> None: medrecord = MedRecord() nodes = create_nodes() @@ -659,87 +644,87 @@ def test_invalid_add_edge(self): medrecord.add_nodes(nodes) # Adding an edge pointing to a non-existent node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_edge("0", "50", {}) # Adding an edge from a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_edge("50", "0", {}) - def test_remove_edge(self): + def test_remove_edge(self) -> None: medrecord = create_medrecord() - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 attributes = medrecord.remove_edge(0) - self.assertEqual(3, medrecord.edge_count()) - self.assertEqual(create_edges()[0][2], attributes) + assert medrecord.edge_count() == 3 + assert create_edges()[0][2] == attributes attributes = medrecord.remove_edge([1, 2]) - self.assertEqual(1, medrecord.edge_count()) - self.assertEqual({1: create_edges()[1][2], 2: create_edges()[2][2]}, attributes) + assert medrecord.edge_count() == 1 + assert attributes == {1: create_edges()[1][2], 2: create_edges()[2][2]} medrecord = create_medrecord() - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 attributes = medrecord.remove_edge(edge_select().index().is_in([0, 1])) - self.assertEqual(2, medrecord.edge_count()) - self.assertEqual({0: create_edges()[0][2], 1: create_edges()[1][2]}, attributes) + assert medrecord.edge_count() == 2 + assert attributes == {0: create_edges()[0][2], 1: create_edges()[1][2]} - def test_invalid_remove_edge(self): + def test_invalid_remove_edge(self) -> None: medrecord = create_medrecord() # Removing a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_edge(50) - def test_add_edges(self): + def test_add_edges(self) -> None: medrecord = MedRecord() nodes = create_nodes() medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges(create_edges()) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 # Adding pandas dataframe medrecord = MedRecord() medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges((create_pandas_edges_dataframe(), "source", "target")) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 # Adding polars dataframe medrecord = MedRecord() medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 edges = pl.from_pandas(create_pandas_edges_dataframe()) medrecord.add_edges((edges, "source", "target")) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 # Adding multiple pandas dataframe medrecord = MedRecord() medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges( [ @@ -748,14 +733,14 @@ def test_add_edges(self): ] ) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 # Adding multiple polats dataframe medrecord = MedRecord() medrecord.add_nodes(nodes) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 second_edges = pl.from_pandas(create_second_pandas_edges_dataframe()) @@ -766,9 +751,9 @@ def test_add_edges(self): ] ) - self.assertEqual(4, medrecord.edge_count()) + assert medrecord.edge_count() == 4 - def test_add_edges_pandas(self): + def test_add_edges_pandas(self) -> None: medrecord = MedRecord() nodes = create_nodes() @@ -777,13 +762,13 @@ def test_add_edges_pandas(self): edges = (create_pandas_edges_dataframe(), "source", "target") - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges(edges) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 - def test_add_edges_polars(self): + def test_add_edges_polars(self) -> None: medrecord = MedRecord() nodes = create_nodes() @@ -792,13 +777,13 @@ def test_add_edges_polars(self): edges = pl.from_pandas(create_pandas_edges_dataframe()) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edges_polars((edges, "source", "target")) - self.assertEqual(2, medrecord.edge_count()) + assert medrecord.edge_count() == 2 - def test_invalid_add_edges_polars(self): + def test_invalid_add_edges_polars(self) -> None: medrecord = MedRecord() nodes = create_nodes() @@ -808,33 +793,33 @@ def test_invalid_add_edges_polars(self): edges = pl.from_pandas(create_pandas_edges_dataframe()) # Providing the wrong source index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_edges_polars((edges, "invalid", "target")) # Providing the wrong target index column name should fail - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): medrecord.add_edges_polars((edges, "source", "invalid")) - def test_add_group(self): + def test_add_group(self) -> None: medrecord = create_medrecord() - self.assertEqual(0, medrecord.group_count()) + assert medrecord.group_count() == 0 medrecord.add_group("0") - self.assertEqual(1, medrecord.group_count()) + assert medrecord.group_count() == 1 medrecord.add_group("1", "0", 0) - self.assertEqual(2, medrecord.group_count()) - self.assertEqual({"nodes": ["0"], "edges": [0]}, medrecord.group("1")) + assert medrecord.group_count() == 2 + assert medrecord.group("1") == {"nodes": ["0"], "edges": [0]} medrecord.add_group("2", ["0", "1"], [0, 1]) - self.assertEqual(3, medrecord.group_count()) + assert medrecord.group_count() == 3 nodes_and_edges = medrecord.group("2") - self.assertEqual(sorted(["0", "1"]), sorted(nodes_and_edges["nodes"])) - self.assertEqual(sorted([0, 1]), sorted(nodes_and_edges["edges"])) + assert sorted(["0", "1"]) == sorted(nodes_and_edges["nodes"]) + assert sorted([0, 1]) == sorted(nodes_and_edges["edges"]) medrecord.add_group( "3", @@ -842,511 +827,451 @@ def test_add_group(self): edge_select().index().is_in([0, 1]), ) - self.assertEqual(4, medrecord.group_count()) + assert medrecord.group_count() == 4 nodes_and_edges = medrecord.group("3") - self.assertEqual(sorted(["0", "1"]), sorted(nodes_and_edges["nodes"])) - self.assertEqual(sorted([0, 1]), sorted(nodes_and_edges["edges"])) + assert sorted(["0", "1"]) == sorted(nodes_and_edges["nodes"]) + assert sorted([0, 1]) == sorted(nodes_and_edges["edges"]) - def test_invalid_add_group(self): + def test_invalid_add_group(self) -> None: medrecord = create_medrecord() # Adding a group with a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_group("0", "50") # Adding an already existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_group("0", ["0", "50"]) medrecord.add_group("0", "0") # Adding an already existing group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0") # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0", "0") # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0", ["1", "0"]) # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_group("0", node_select().index() == "0") - def test_remove_group(self): + def test_remove_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0") - self.assertEqual(1, medrecord.group_count()) + assert medrecord.group_count() == 1 medrecord.remove_group("0") - self.assertEqual(0, medrecord.group_count()) + assert medrecord.group_count() == 0 - def test_invalid_remove_group(self): + def test_invalid_remove_group(self) -> None: medrecord = create_medrecord() # Removing a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_group("0") - def test_add_node_to_group(self): + def test_add_node_to_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0") - self.assertEqual([], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == [] medrecord.add_node_to_group("0", "0") - self.assertEqual(["0"], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == ["0"] medrecord.add_node_to_group("0", ["1", "2"]) - self.assertEqual( - sorted(["0", "1", "2"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1", "2"]) == sorted(medrecord.nodes_in_group("0")) medrecord.add_node_to_group("0", node_select().index() == "3") - self.assertEqual( - sorted(["0", "1", "2", "3"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1", "2", "3"]) == sorted(medrecord.nodes_in_group("0")) - def test_invalid_add_node_to_group(self): + def test_invalid_add_node_to_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", ["0"]) # Adding to a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_node_to_group("50", "1") # Adding to a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_node_to_group("50", ["1", "2"]) # Adding a non-existing node to a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_node_to_group("0", "50") # Adding a non-existing node to a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_node_to_group("0", ["1", "50"]) # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_node_to_group("0", "0") # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_node_to_group("0", ["1", "0"]) # Adding a node to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_node_to_group("0", node_select().index() == "0") - def test_add_edge_to_group(self): + def test_add_edge_to_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0") - self.assertEqual([], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [] medrecord.add_edge_to_group("0", 0) - self.assertEqual([0], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [0] medrecord.add_edge_to_group("0", [1, 2]) - self.assertEqual( - sorted([0, 1, 2]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1, 2]) == sorted(medrecord.edges_in_group("0")) medrecord.add_edge_to_group("0", edge_select().index() == 3) - self.assertEqual( - sorted([0, 1, 2, 3]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1, 2, 3]) == sorted(medrecord.edges_in_group("0")) - def test_invalid_add_edge_to_group(self): + def test_invalid_add_edge_to_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", edges=[0]) # Adding to a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_edge_to_group("50", 1) # Adding to a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_edge_to_group("50", [1, 2]) # Adding a non-existing edge to a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_edge_to_group("0", 50) # Adding a non-existing edge to a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.add_edge_to_group("0", [1, 50]) # Adding an edge to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_edge_to_group("0", 0) # Adding an edge to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_edge_to_group("0", [1, 0]) # Adding an edge to a group that already is in the group should fail - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): medrecord.add_edge_to_group("0", edge_select().index() == 0) - def test_remove_node_from_group(self): + def test_remove_node_from_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) medrecord.remove_node_from_group("0", "1") - self.assertEqual(["0"], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == ["0"] medrecord.add_node_to_group("0", "1") - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) medrecord.remove_node_from_group("0", ["0", "1"]) - self.assertEqual([], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == [] medrecord.add_node_to_group("0", ["0", "1"]) - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) medrecord.remove_node_from_group("0", node_select().index().is_in(["0", "1"])) - self.assertEqual([], medrecord.nodes_in_group("0")) + assert medrecord.nodes_in_group("0") == [] - def test_invalid_remove_node_from_group(self): + def test_invalid_remove_node_from_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) # Removing a node from a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_node_from_group("50", "0") # Removing a node from a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_node_from_group("50", ["0", "1"]) # Removing a node from a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_node_from_group("50", node_select().index() == "0") # Removing a non-existing node from a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_node_from_group("0", "50") # Removing a non-existing node from a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_node_from_group("0", ["0", "50"]) - def test_remove_edge_from_group(self): + def test_remove_edge_from_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) medrecord.remove_edge_from_group("0", 1) - self.assertEqual([0], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [0] medrecord.add_edge_to_group("0", 1) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) medrecord.remove_edge_from_group("0", [0, 1]) - self.assertEqual([], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [] medrecord.add_edge_to_group("0", [0, 1]) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) medrecord.remove_edge_from_group("0", edge_select().index().is_in([0, 1])) - self.assertEqual([], medrecord.edges_in_group("0")) + assert medrecord.edges_in_group("0") == [] - def test_invalid_remove_edge_from_group(self): + def test_invalid_remove_edge_from_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) # Removing an edge from a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_edge_from_group("50", 0) # Removing an edge from a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_edge_from_group("50", [0, 1]) # Removing an edge from a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_edge_from_group("50", edge_select().index() == 0) # Removing a non-existing edge from a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_edge_from_group("0", 50) # Removing a non-existing edge from a group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.remove_edge_from_group("0", [0, 50]) - def test_nodes_in_group(self): + def test_nodes_in_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) - self.assertEqual( - sorted(["0", "1"]), - sorted(medrecord.nodes_in_group("0")), - ) + assert sorted(["0", "1"]) == sorted(medrecord.nodes_in_group("0")) - def test_invalid_nodes_in_group(self): + def test_invalid_nodes_in_group(self) -> None: medrecord = create_medrecord() # Querying nodes in a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.nodes_in_group("50") - def test_edges_in_group(self): + def test_edges_in_group(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.edges_in_group("0")), - ) + assert sorted([0, 1]) == sorted(medrecord.edges_in_group("0")) - def test_invalid_edges_in_group(self): + def test_invalid_edges_in_group(self) -> None: medrecord = create_medrecord() # Querying edges in a non-existing group should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.edges_in_group("50") - def test_groups_of_node(self): + def test_groups_of_node(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", ["0", "1"]) - self.assertEqual(["0"], medrecord.groups_of_node("0")) + assert medrecord.groups_of_node("0") == ["0"] - self.assertEqual({"0": ["0"], "1": ["0"]}, medrecord.groups_of_node(["0", "1"])) + assert medrecord.groups_of_node(["0", "1"]) == {"0": ["0"], "1": ["0"]} - self.assertEqual( - {"0": ["0"], "1": ["0"]}, - medrecord.groups_of_node(node_select().index().is_in(["0", "1"])), - ) + assert medrecord.groups_of_node(node_select().index().is_in(["0", "1"])) == {"0": ["0"], "1": ["0"]} - def test_invalid_groups_of_node(self): + def test_invalid_groups_of_node(self) -> None: medrecord = create_medrecord() # Querying groups of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_node("50") # Querying groups of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_node(["0", "50"]) - def test_groups_of_edge(self): + def test_groups_of_edge(self) -> None: medrecord = create_medrecord() medrecord.add_group("0", edges=[0, 1]) - self.assertEqual(["0"], medrecord.groups_of_edge(0)) + assert medrecord.groups_of_edge(0) == ["0"] - self.assertEqual({0: ["0"], 1: ["0"]}, medrecord.groups_of_edge([0, 1])) + assert medrecord.groups_of_edge([0, 1]) == {0: ["0"], 1: ["0"]} - self.assertEqual( - {0: ["0"], 1: ["0"]}, - medrecord.groups_of_edge(edge_select().index().is_in([0, 1])), - ) + assert medrecord.groups_of_edge(edge_select().index().is_in([0, 1])) == {0: ["0"], 1: ["0"]} - def test_invalid_groups_of_edge(self): + def test_invalid_groups_of_edge(self) -> None: medrecord = create_medrecord() # Querying groups of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_edge(50) # Querying groups of a non-existing edge should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.groups_of_edge([0, 50]) - def test_node_count(self): + def test_node_count(self) -> None: medrecord = MedRecord() - self.assertEqual(0, medrecord.node_count()) + assert medrecord.node_count() == 0 medrecord.add_node("0", {}) - self.assertEqual(1, medrecord.node_count()) + assert medrecord.node_count() == 1 - def test_edge_count(self): + def test_edge_count(self) -> None: medrecord = MedRecord() medrecord.add_node("0", {}) medrecord.add_node("1", {}) - self.assertEqual(0, medrecord.edge_count()) + assert medrecord.edge_count() == 0 medrecord.add_edge("0", "1", {}) - self.assertEqual(1, medrecord.edge_count()) + assert medrecord.edge_count() == 1 - def test_group_count(self): + def test_group_count(self) -> None: medrecord = create_medrecord() - self.assertEqual(0, medrecord.group_count()) + assert medrecord.group_count() == 0 medrecord.add_group("0") - self.assertEqual(1, medrecord.group_count()) + assert medrecord.group_count() == 1 - def test_contains_node(self): + def test_contains_node(self) -> None: medrecord = create_medrecord() - self.assertTrue(medrecord.contains_node("0")) + assert medrecord.contains_node("0") - self.assertFalse(medrecord.contains_node("50")) + assert not medrecord.contains_node("50") - def test_contains_edge(self): + def test_contains_edge(self) -> None: medrecord = create_medrecord() - self.assertTrue(medrecord.contains_edge(0)) + assert medrecord.contains_edge(0) - self.assertFalse(medrecord.contains_edge(50)) + assert not medrecord.contains_edge(50) - def test_contains_group(self): + def test_contains_group(self) -> None: medrecord = create_medrecord() - self.assertFalse(medrecord.contains_group("0")) + assert not medrecord.contains_group("0") medrecord.add_group("0") - self.assertTrue(medrecord.contains_group("0")) + assert medrecord.contains_group("0") - def test_neighbors(self): + def test_neighbors(self) -> None: medrecord = create_medrecord() neighbors = medrecord.neighbors("0") - self.assertEqual( - sorted(["1", "3"]), - sorted(neighbors), - ) + assert sorted(["1", "3"]) == sorted(neighbors) neighbors = medrecord.neighbors(["0", "1"]) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} neighbors = medrecord.neighbors(node_select().index().is_in(["0", "1"])) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} neighbors = medrecord.neighbors("0", directed=False) - self.assertEqual( - sorted(["1", "3"]), - sorted(neighbors), - ) + assert sorted(["1", "3"]) == sorted(neighbors) neighbors = medrecord.neighbors(["0", "1"], directed=False) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} neighbors = medrecord.neighbors( node_select().index().is_in(["0", "1"]), directed=False ) - self.assertEqual( - {"0": sorted(["1", "3"]), "1": ["0", "2"]}, - {key: sorted(value) for (key, value) in neighbors.items()}, - ) + assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} - def test_invalid_neighbors(self): + def test_invalid_neighbors(self) -> None: medrecord = create_medrecord() # Querying neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors("50") # Querying neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors(["0", "50"]) # Querying undirected neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors("50", directed=False) # Querying undirected neighbors of a non-existing node should fail - with self.assertRaises(IndexError): + with pytest.raises(IndexError): medrecord.neighbors(["0", "50"], directed=False) - def test_clear(self): + def test_clear(self) -> None: medrecord = create_medrecord() - self.assertEqual(4, medrecord.node_count()) - self.assertEqual(4, medrecord.edge_count()) - self.assertEqual(0, medrecord.group_count()) + assert medrecord.node_count() == 4 + assert medrecord.edge_count() == 4 + assert medrecord.group_count() == 0 medrecord.clear() - self.assertEqual(0, medrecord.node_count()) - self.assertEqual(0, medrecord.edge_count()) - self.assertEqual(0, medrecord.group_count()) + assert medrecord.node_count() == 0 + assert medrecord.edge_count() == 0 + assert medrecord.group_count() == 0 if __name__ == "__main__": diff --git a/medmodels/medrecord/tests/test_querying.py b/medmodels/medrecord/tests/test_querying.py index 808961ab..d7b5b7e2 100644 --- a/medmodels/medrecord/tests/test_querying.py +++ b/medmodels/medrecord/tests/test_querying.py @@ -38,1133 +38,399 @@ def create_medrecord() -> MedRecord: class TestMedRecord(unittest.TestCase): - def test_select_nodes_node(self): + def test_select_nodes_node(self) -> None: medrecord = create_medrecord() medrecord.add_group("test", ["0"]) # Node in group - self.assertEqual(["0"], medrecord.select_nodes(node().in_group("test"))) + assert medrecord.select_nodes(node().in_group("test")) == ["0"] # Node has attribute - self.assertEqual(["0"], medrecord.select_nodes(node().has_attribute("lorem"))) + assert medrecord.select_nodes(node().has_attribute("lorem")) == ["0"] # Node has outgoing edge with - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().has_outgoing_edge_with(edge().index().equal(0)) - ), - ) + assert medrecord.select_nodes(node().has_outgoing_edge_with(edge().index().equal(0))) == ["0"] # Node has incoming edge with - self.assertEqual( - ["1"], - medrecord.select_nodes( - node().has_incoming_edge_with(edge().index().equal(0)) - ), - ) + assert medrecord.select_nodes(node().has_incoming_edge_with(edge().index().equal(0))) == ["1"] # Node has edge with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes(node().has_edge_with(edge().index().equal(0))) - ), - ) + assert sorted(["0", "1"]) == sorted(medrecord.select_nodes(node().has_edge_with(edge().index().equal(0)))) # Node has neighbor with - self.assertEqual( - sorted(["0", "1"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("2")) - ) - ), - ) - self.assertEqual( - sorted(["0"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=True) - ) - ), - ) + assert sorted(["0", "1"]) == sorted(medrecord.select_nodes(node().has_neighbor_with(node().index().equal("2")))) + assert sorted(["0"]) == sorted(medrecord.select_nodes(node().has_neighbor_with(node().index().equal("1"), directed=True))) # Node has neighbor with - self.assertEqual( - sorted(["0", "2"]), - sorted( - medrecord.select_nodes( - node().has_neighbor_with(node().index().equal("1"), directed=False) - ) - ), - ) - - def test_select_nodes_node_index(self): + assert sorted(["0", "2"]) == sorted(medrecord.select_nodes(node().has_neighbor_with(node().index().equal("1"), directed=False))) + + def test_select_nodes_node_index(self) -> None: medrecord = create_medrecord() # Index greater - self.assertEqual( - sorted(["2", "3"]), - sorted(medrecord.select_nodes(node().index().greater("1"))), - ) + assert sorted(["2", "3"]) == sorted(medrecord.select_nodes(node().index().greater("1"))) # Index less - self.assertEqual( - sorted(["0", "1"]), sorted(medrecord.select_nodes(node().index().less("2"))) - ) + assert sorted(["0", "1"]) == sorted(medrecord.select_nodes(node().index().less("2"))) # Index greater or equal - self.assertEqual( - sorted(["1", "2", "3"]), - sorted(medrecord.select_nodes(node().index().greater_or_equal("1"))), - ) + assert sorted(["1", "2", "3"]) == sorted(medrecord.select_nodes(node().index().greater_or_equal("1"))) # Index less or equal - self.assertEqual( - sorted(["0", "1", "2"]), - sorted(medrecord.select_nodes(node().index().less_or_equal("2"))), - ) + assert sorted(["0", "1", "2"]) == sorted(medrecord.select_nodes(node().index().less_or_equal("2"))) # Index equal - self.assertEqual(["1"], medrecord.select_nodes(node().index().equal("1"))) + assert medrecord.select_nodes(node().index().equal("1")) == ["1"] # Index not equal - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_equal("1"))), - ) + assert sorted(["0", "2", "3"]) == sorted(medrecord.select_nodes(node().index().not_equal("1"))) # Index in - self.assertEqual(["1"], medrecord.select_nodes(node().index().is_in(["1"]))) + assert medrecord.select_nodes(node().index().is_in(["1"])) == ["1"] # Index not in - self.assertEqual( - sorted(["0", "2", "3"]), - sorted(medrecord.select_nodes(node().index().not_in(["1"]))), - ) + assert sorted(["0", "2", "3"]) == sorted(medrecord.select_nodes(node().index().not_in(["1"]))) # Index starts with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().starts_with("1")), - ) + assert medrecord.select_nodes(node().index().starts_with("1")) == ["1"] # Index ends with - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().ends_with("1")), - ) + assert medrecord.select_nodes(node().index().ends_with("1")) == ["1"] # Index contains - self.assertEqual( - ["1"], - medrecord.select_nodes(node().index().contains("1")), - ) + assert medrecord.select_nodes(node().index().contains("1")) == ["1"] - def test_select_nodes_node_attribute(self): + def test_select_nodes_node_attribute(self) -> None: medrecord = create_medrecord() # Attribute greater - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").greater("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") > "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").greater("ipsum")) == [] + assert medrecord.select_nodes(node().attribute("lorem") > "ipsum") == [] # Attribute less - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").less("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") < "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").less("ipsum")) == [] + assert medrecord.select_nodes(node().attribute("lorem") < "ipsum") == [] # Attribute greater or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").greater_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") >= "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").greater_or_equal("ipsum")) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") >= "ipsum") == ["0"] # Attribute less or equal - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").less_or_equal("ipsum")), - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") <= "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").less_or_equal("ipsum")) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") <= "ipsum") == ["0"] # Attribute equal - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").equal("ipsum")) - ) - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem") == "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").equal("ipsum")) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") == "ipsum") == ["0"] # Attribute not equal - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) - ) - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem") != "ipsum") - ) + assert medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) == [] + assert medrecord.select_nodes(node().attribute("lorem") != "ipsum") == [] # Attribute in - self.assertEqual( - ["0"], medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) - ) + assert medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) == ["0"] # Attribute not in - self.assertEqual( - [], medrecord.select_nodes(node().attribute("lorem").not_in(["ipsum"])) - ) + assert medrecord.select_nodes(node().attribute("lorem").not_in(["ipsum"])) == [] # Attribute starts with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").starts_with("ip")), - ) + assert medrecord.select_nodes(node().attribute("lorem").starts_with("ip")) == ["0"] # Attribute ends with - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").ends_with("um")), - ) + assert medrecord.select_nodes(node().attribute("lorem").ends_with("um")) == ["0"] # Attribute contains - self.assertEqual( - ["0"], - medrecord.select_nodes(node().attribute("lorem").contains("su")), - ) + assert medrecord.select_nodes(node().attribute("lorem").contains("su")) == ["0"] # Attribute compare to attribute - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem")) - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem"))) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem"))) == [] # Attribute compare to attribute add - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").add("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") + "10" - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").add("10")) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") + "10" - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").add("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") + "10") == [] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").add("10"))) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") + "10") == ["0"] # Attribute compare to attribute sub # Returns nothing because can't sub a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") + "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") - "10" - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").sub("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") + "10") == [] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").sub("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") - "10") == [] # Attribute compare to attribute sub - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sub(10)) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").sub(10))) == [] + assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").sub(10))) == ["0"] # Attribute compare to attribute mul - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").mul(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") * 2 - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").mul(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") * 2 - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").mul(2))) == [] + assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") * 2) == [] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").mul(2))) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") * 2) == ["0"] # Attribute compare to attribute div # Returns nothing because can't div a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") / "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") / "10" - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").div("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") / "10") == [] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").div("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") / "10") == [] # Attribute compare to attribute div - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").div(2)) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").div(2)) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").div(2))) == [] + assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").div(2))) == ["0"] # Attribute compare to attribute pow # Returns nothing because can't pow a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") ** "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") ** "10" - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").pow("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") ** "10") == [] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").pow("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") ** "10") == [] # Attribute compare to attribute pow - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").pow(2)) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").pow(2))) == ["0"] + assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").pow(2))) == [] # Attribute compare to attribute mod # Returns nothing because can't mod a string - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") == node().attribute("lorem") % "10" - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem") != node().attribute("lorem") % "10" - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").mod("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") % "10") == [] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").mod("10"))) == [] + assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") % "10") == [] # Attribute compare to attribute mod - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").mod(2)) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").mod(2))) == ["0"] + assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").mod(2))) == [] # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").not_equal(node().attribute("lorem").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").round()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").round()) - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").round())) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").round())) == [] + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("float").round())) == ["0"] + assert medrecord.select_nodes(node().attribute("float").not_equal(node().attribute("float").round())) == ["0"] # Attribute compare to attribute round - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").ceil()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").ceil()) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("float").ceil())) == ["0"] + assert medrecord.select_nodes(node().attribute("float").not_equal(node().attribute("float").ceil())) == ["0"] # Attribute compare to attribute floor - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("float").floor()) - ), - ) - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("float").not_equal(node().attribute("float").floor()) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("float").floor())) == [] + assert medrecord.select_nodes(node().attribute("float").not_equal(node().attribute("float").floor())) == ["0"] # Attribute compare to attribute abs - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("integer").not_equal(node().attribute("integer").abs()) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").abs())) == ["0"] + assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").abs())) == [] # Attribute compare to attribute sqrt - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("integer").equal(node().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], - medrecord.select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sqrt()) - ), - ) + assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").sqrt())) == ["0"] + assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").sqrt())) == [] # Attribute compare to attribute trim - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim()) - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("dolor").trim())) == ["0"] # Attribute compare to attribute trim_start - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim_start()) - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("dolor").trim_start())) == [] # Attribute compare to attribute trim_end - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("dolor").trim_end()) - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("dolor").trim_end())) == [] # Attribute compare to attribute lowercase - self.assertEqual( - ["0"], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("test").lowercase()) - ), - ) + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("test").lowercase())) == ["0"] # Attribute compare to attribute uppercase - self.assertEqual( - [], - medrecord.select_nodes( - node().attribute("lorem").equal(node().attribute("test").uppercase()) - ), - ) - - def test_select_edges_edge(self): + assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("test").uppercase())) == [] + + def test_select_edges_edge(self) -> None: medrecord = create_medrecord() medrecord.add_group("test", edges=[0]) # Edge connected to target - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().connected_target("2"))), - ) + assert sorted([1, 2, 3]) == sorted(medrecord.select_edges(edge().connected_target("2"))) # Edge connected to source - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().connected_source("0"))), - ) + assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().connected_source("0"))) # Edge connected - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().connected("1"))), - ) + assert sorted([0, 1]) == sorted(medrecord.select_edges(edge().connected("1"))) # Edge in group - self.assertEqual( - [0], - medrecord.select_edges(edge().in_group("test")), - ) + assert medrecord.select_edges(edge().in_group("test")) == [0] # Edge has attribute - self.assertEqual( - [0], - medrecord.select_edges(edge().has_attribute("sed")), - ) + assert medrecord.select_edges(edge().has_attribute("sed")) == [0] # Edge connected to target with - self.assertEqual( - [0], - medrecord.select_edges( - edge().connected_target_with(node().index().equal("1")) - ), - ) + assert medrecord.select_edges(edge().connected_target_with(node().index().equal("1"))) == [0] # Edge connected to source with - self.assertEqual( - sorted([0, 2, 3]), - sorted( - medrecord.select_edges( - edge().connected_source_with(node().index().equal("0")) - ) - ), - ) + assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().connected_source_with(node().index().equal("0")))) # Edge connected with - self.assertEqual( - sorted([0, 1]), - sorted( - medrecord.select_edges(edge().connected_with(node().index().equal("1"))) - ), - ) + assert sorted([0, 1]) == sorted(medrecord.select_edges(edge().connected_with(node().index().equal("1")))) # Edge has parallel edges with - self.assertEqual( - sorted([2, 3]), - sorted( - medrecord.select_edges( - edge().has_parallel_edges_with(edge().has_attribute("test")) - ) - ), - ) + assert sorted([2, 3]) == sorted(medrecord.select_edges(edge().has_parallel_edges_with(edge().has_attribute("test")))) # Edge has parallel edges with self comparison - self.assertEqual( - [2], - medrecord.select_edges( - edge().has_parallel_edges_with_self_comparison( - edge().attribute("test").equal(edge().attribute("test").sub(1)) - ) - ), - ) - - def test_select_edges_edge_index(self): + assert medrecord.select_edges(edge().has_parallel_edges_with_self_comparison(edge().attribute("test").equal(edge().attribute("test").sub(1)))) == [2] + + def test_select_edges_edge_index(self) -> None: medrecord = create_medrecord() # Index greater - self.assertEqual( - sorted([2, 3]), - sorted(medrecord.select_edges(edge().index().greater(1))), - ) + assert sorted([2, 3]) == sorted(medrecord.select_edges(edge().index().greater(1))) # Index less - self.assertEqual( - [0], - medrecord.select_edges(edge().index().less(1)), - ) + assert medrecord.select_edges(edge().index().less(1)) == [0] # Index greater or equal - self.assertEqual( - sorted([1, 2, 3]), - sorted(medrecord.select_edges(edge().index().greater_or_equal(1))), - ) + assert sorted([1, 2, 3]) == sorted(medrecord.select_edges(edge().index().greater_or_equal(1))) # Index less or equal - self.assertEqual( - sorted([0, 1]), - sorted(medrecord.select_edges(edge().index().less_or_equal(1))), - ) + assert sorted([0, 1]) == sorted(medrecord.select_edges(edge().index().less_or_equal(1))) # Index equal - self.assertEqual( - [1], - medrecord.select_edges(edge().index().equal(1)), - ) + assert medrecord.select_edges(edge().index().equal(1)) == [1] # Index not equal - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_equal(1))), - ) + assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().index().not_equal(1))) # Index in - self.assertEqual( - [1], - medrecord.select_edges(edge().index().is_in([1])), - ) + assert medrecord.select_edges(edge().index().is_in([1])) == [1] # Index not in - self.assertEqual( - sorted([0, 2, 3]), - sorted(medrecord.select_edges(edge().index().not_in([1]))), - ) + assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().index().not_in([1]))) - def test_select_edges_edges_attribute(self): + def test_select_edges_edges_attribute(self) -> None: medrecord = create_medrecord() # Attribute greater - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").greater("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").greater("do")) == [] # Attribute less - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").less("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").less("do")) == [] # Attribute greater or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").greater_or_equal("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").greater_or_equal("do")) == [0] # Attribute less or equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").less_or_equal("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").less_or_equal("do")) == [0] # Attribute equal - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").equal("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").equal("do")) == [0] # Attribute not equal - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_equal("do")), - ) + assert medrecord.select_edges(edge().attribute("sed").not_equal("do")) == [] # Attribute in - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").is_in(["do"])), - ) + assert medrecord.select_edges(edge().attribute("sed").is_in(["do"])) == [0] # Attribute not in - self.assertEqual( - [], - medrecord.select_edges(edge().attribute("sed").not_in(["do"])), - ) + assert medrecord.select_edges(edge().attribute("sed").not_in(["do"])) == [] # Attribute starts with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").starts_with("d")), - ) + assert medrecord.select_edges(edge().attribute("sed").starts_with("d")) == [0] # Attribute ends with - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").ends_with("o")), - ) + assert medrecord.select_edges(edge().attribute("sed").ends_with("o")) == [0] # Attribute contains - self.assertEqual( - [0], - medrecord.select_edges(edge().attribute("sed").contains("d")), - ) + assert medrecord.select_edges(edge().attribute("sed").contains("d")) == [0] # Attribute compare to attribute - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed")) - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed"))) == [0] + assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed"))) == [] # Attribute compare to attribute add - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").add("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") + "10" - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").add("10")) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") + "10" - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").add("10"))) == [] + assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") + "10") == [] + assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").add("10"))) == [0] + assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") + "10") == [0] # Attribute compare to attribute sub # Returns nothing because can't sub a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") - "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").sub("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") - "10" - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").sub("10"))) == [] + assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") - "10") == [] + assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").sub("10"))) == [] + assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") - "10") == [] # Attribute compare to attribute sub - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").sub(10)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sub(10)) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").sub(10))) == [] + assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").sub(10))) == [2] # Attribute compare to attribute mul - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").mul(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") * 2 - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").mul(2)) - ), - ) - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") * 2 - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").mul(2))) == [] + assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") * 2) == [] + assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").mul(2))) == [0] + assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") * 2) == [0] # Attribute compare to attribute div # Returns nothing because can't div a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") == edge().attribute("sed") / "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").div("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed") != edge().attribute("sed") / "10" - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").div("10"))) == [] + assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") / "10") == [] + assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").div("10"))) == [] + assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") / "10") == [] # Attribute compare to attribute div - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").div(2)) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").div(2)) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").div(2))) == [] + assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").div(2))) == [2] # Attribute compare to attribute pow # Returns nothing because can't pow a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").equal(edge().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") == edge().attribute("lorem") ** "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").not_equal(edge().attribute("lorem").pow("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") != edge().attribute("lorem") ** "10" - ), - ) + assert medrecord.select_edges(edge().attribute("lorem").equal(edge().attribute("lorem").pow("10"))) == [] + assert medrecord.select_edges(edge().attribute("lorem") == edge().attribute("lorem") ** "10") == [] + assert medrecord.select_edges(edge().attribute("lorem").not_equal(edge().attribute("lorem").pow("10"))) == [] + assert medrecord.select_edges(edge().attribute("lorem") != edge().attribute("lorem") ** "10") == [] # Attribute compare to attribute pow - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").pow(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").pow(2)) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").pow(2))) == [2] + assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").pow(2))) == [] # Attribute compare to attribute mod # Returns nothing because can't mod a string - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").equal(edge().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") == edge().attribute("lorem") % "10" - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem").not_equal(edge().attribute("lorem").mod("10")) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("lorem") != edge().attribute("lorem") % "10" - ), - ) + assert medrecord.select_edges(edge().attribute("lorem").equal(edge().attribute("lorem").mod("10"))) == [] + assert medrecord.select_edges(edge().attribute("lorem") == edge().attribute("lorem") % "10") == [] + assert medrecord.select_edges(edge().attribute("lorem").not_equal(edge().attribute("lorem").mod("10"))) == [] + assert medrecord.select_edges(edge().attribute("lorem") != edge().attribute("lorem") % "10") == [] # Attribute compare to attribute mod - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").mod(2)) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").mod(2)) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").mod(2))) == [2] + assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").mod(2))) == [] # Attribute compare to attribute round - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").not_equal(edge().attribute("sed").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").round()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").round()) - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").round())) == [0] + assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").round())) == [] + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("float").round())) == [2] + assert medrecord.select_edges(edge().attribute("float").not_equal(edge().attribute("float").round())) == [2] # Attribute compare to attribute ceil - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").ceil()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").ceil()) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("float").ceil())) == [2] + assert medrecord.select_edges(edge().attribute("float").not_equal(edge().attribute("float").ceil())) == [2] # Attribute compare to attribute floor - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("float").floor()) - ), - ) - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("float").not_equal(edge().attribute("float").floor()) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("float").floor())) == [] + assert medrecord.select_edges(edge().attribute("float").not_equal(edge().attribute("float").floor())) == [2] # Attribute compare to attribute abs - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").abs()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("integer").not_equal(edge().attribute("integer").abs()) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").abs())) == [2] + assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").abs())) == [] # Attribute compare to attribute sqrt - self.assertEqual( - [2], - medrecord.select_edges( - edge().attribute("integer").equal(edge().attribute("integer").sqrt()) - ), - ) - self.assertEqual( - [], - medrecord.select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sqrt()) - ), - ) + assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").sqrt())) == [2] + assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").sqrt())) == [] # Attribute compare to attribute trim - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim()) - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("dolor").trim())) == [0] # Attribute compare to attribute trim_start - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim_start()) - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("dolor").trim_start())) == [] # Attribute compare to attribute trim_end - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("dolor").trim_end()) - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("dolor").trim_end())) == [] # Attribute compare to attribute lowercase - self.assertEqual( - [0], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("test").lowercase()) - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("test").lowercase())) == [0] # Attribute compare to attribute uppercase - self.assertEqual( - [], - medrecord.select_edges( - edge().attribute("sed").equal(edge().attribute("test").uppercase()) - ), - ) + assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("test").uppercase())) == [] diff --git a/medmodels/medrecord/tests/test_schema.py b/medmodels/medrecord/tests/test_schema.py index f0dc98fd..3dfaa342 100644 --- a/medmodels/medrecord/tests/test_schema.py +++ b/medmodels/medrecord/tests/test_schema.py @@ -1,5 +1,7 @@ import unittest +import pytest + import medmodels.medrecord as mr from medmodels._medmodels import PyAttributeType from medmodels.medrecord.schema import GroupSchema, Schema @@ -10,69 +12,44 @@ def create_medrecord() -> mr.MedRecord: class TestSchema(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.schema = create_medrecord().schema - def test_groups(self): - self.assertEqual( - sorted( - [ - "diagnosis", - "drug", - "patient_diagnosis", - "patient_drug", - "patient_procedure", - "patient", - "procedure", - ] - ), - sorted(self.schema.groups), - ) + def test_groups(self) -> None: + assert sorted(["diagnosis", "drug", "patient_diagnosis", "patient_drug", "patient_procedure", "patient", "procedure"]) == sorted(self.schema.groups) - def test_group(self): - self.assertTrue(isinstance(self.schema.group("patient"), mr.GroupSchema)) # pyright: ignore[reportUnnecessaryIsInstance] + def test_group(self) -> None: + assert isinstance(self.schema.group("patient"), mr.GroupSchema) # pyright: ignore[reportUnnecessaryIsInstance] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.schema.group("nonexistent") - def test_default(self): - self.assertEqual(None, self.schema.default) + def test_default(self) -> None: + assert None is self.schema.default schema = Schema(default=GroupSchema(nodes={"description": mr.String()})) - self.assertTrue(isinstance(schema.default, mr.GroupSchema)) + assert isinstance(schema.default, mr.GroupSchema) - def test_strict(self): - self.assertEqual(True, self.schema.strict) + def test_strict(self) -> None: + assert True is self.schema.strict class TestGroupSchema(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.schema = create_medrecord().schema - def test_nodes(self): - self.assertEqual( - { - "age": (mr.Int(), mr.AttributeType.Continuous), - "gender": (mr.String(), mr.AttributeType.Categorical), - }, - self.schema.group("patient").nodes, - ) + def test_nodes(self) -> None: + assert self.schema.group("patient").nodes == {"age": (mr.Int(), mr.AttributeType.Continuous), "gender": (mr.String(), mr.AttributeType.Categorical)} - def test_edges(self): - self.assertEqual( - { - "diagnosis_time": (mr.DateTime(), mr.AttributeType.Temporal), - "duration_days": (mr.Option(mr.Float()), mr.AttributeType.Continuous), - }, - self.schema.group("patient_diagnosis").edges, - ) + def test_edges(self) -> None: + assert self.schema.group("patient_diagnosis").edges == {"diagnosis_time": (mr.DateTime(), mr.AttributeType.Temporal), "duration_days": (mr.Option(mr.Float()), mr.AttributeType.Continuous)} - def test_strict(self): - self.assertEqual(True, self.schema.group("patient").strict) + def test_strict(self) -> None: + assert True is self.schema.group("patient").strict class TestAttributesSchema(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.attributes_schema = ( Schema( groups={"diagnosis": GroupSchema(nodes={"description": mr.String()})}, @@ -82,11 +59,8 @@ def setUp(self): .nodes ) - def test_repr(self): - self.assertEqual( - "{'description': (DataType.String, None)}", - repr(self.attributes_schema), - ) + def test_repr(self) -> None: + assert repr(self.attributes_schema) == "{'description': (DataType.String, None)}" second_attributes_schema = ( Schema( @@ -103,28 +77,22 @@ def test_repr(self): .nodes ) - self.assertEqual( - "{'description': (DataType.String, AttributeType.Categorical)}", - repr(second_attributes_schema), - ) + assert repr(second_attributes_schema) == "{'description': (DataType.String, AttributeType.Categorical)}" - def test_getitem(self): - self.assertEqual( - (mr.String(), None), - self.attributes_schema["description"], - ) + def test_getitem(self) -> None: + assert (mr.String(), None) == self.attributes_schema["description"] - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.attributes_schema["nonexistent"] - def test_contains(self): - self.assertTrue("description" in self.attributes_schema) - self.assertFalse("nonexistent" in self.attributes_schema) + def test_contains(self) -> None: + assert "description" in self.attributes_schema + assert "nonexistent" not in self.attributes_schema - def test_len(self): - self.assertEqual(1, len(self.attributes_schema)) + def test_len(self) -> None: + assert len(self.attributes_schema) == 1 - def test_eq(self): + def test_eq(self) -> None: comparison_attributes_schema = ( Schema( groups={"diagnosis": GroupSchema(nodes={"description": mr.String()})}, @@ -134,7 +102,7 @@ def test_eq(self): .nodes ) - self.assertEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema == comparison_attributes_schema comparison_attributes_schema = ( Schema( @@ -145,7 +113,7 @@ def test_eq(self): .nodes ) - self.assertNotEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema != comparison_attributes_schema comparison_attributes_schema = ( Schema( @@ -162,7 +130,7 @@ def test_eq(self): .nodes ) - self.assertNotEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema != comparison_attributes_schema comparison_attributes_schema = ( Schema( @@ -179,60 +147,45 @@ def test_eq(self): .nodes ) - self.assertNotEqual(self.attributes_schema, comparison_attributes_schema) + assert self.attributes_schema != comparison_attributes_schema - self.assertNotEqual(self.attributes_schema, None) + assert self.attributes_schema is not None - def test_keys(self): - self.assertEqual(["description"], list(self.attributes_schema.keys())) + def test_keys(self) -> None: + assert list(self.attributes_schema.keys()) == ["description"] - def test_values(self): - self.assertEqual( - [(mr.String(), None)], - list(self.attributes_schema.values()), - ) + def test_values(self) -> None: + assert [(mr.String(), None)] == list(self.attributes_schema.values()) - def test_items(self): - self.assertEqual( - [("description", (mr.String(), None))], - list(self.attributes_schema.items()), - ) + def test_items(self) -> None: + assert [("description", (mr.String(), None))] == list(self.attributes_schema.items()) - def test_get(self): - self.assertEqual( - (mr.String(), None), - self.attributes_schema.get("description"), - ) + def test_get(self) -> None: + assert (mr.String(), None) == self.attributes_schema.get("description") - self.assertEqual( - None, - self.attributes_schema.get("nonexistent"), - ) + assert None is self.attributes_schema.get("nonexistent") - self.assertEqual( - (mr.String(), None), - self.attributes_schema.get("nonexistent", (mr.String(), None)), - ) + assert (mr.String(), None) == self.attributes_schema.get("nonexistent", (mr.String(), None)) class TestAttributeType(unittest.TestCase): - def test_str(self): - self.assertEqual("Categorical", str(mr.AttributeType.Categorical)) - self.assertEqual("Continuous", str(mr.AttributeType.Continuous)) - self.assertEqual("Temporal", str(mr.AttributeType.Temporal)) - - def test_eq(self): - self.assertEqual(mr.AttributeType.Categorical, mr.AttributeType.Categorical) - self.assertEqual(mr.AttributeType.Categorical, PyAttributeType.Categorical) - self.assertNotEqual(mr.AttributeType.Categorical, mr.AttributeType.Continuous) - self.assertNotEqual(mr.AttributeType.Categorical, PyAttributeType.Continuous) - - self.assertEqual(mr.AttributeType.Continuous, mr.AttributeType.Continuous) - self.assertEqual(mr.AttributeType.Continuous, PyAttributeType.Continuous) - self.assertNotEqual(mr.AttributeType.Continuous, mr.AttributeType.Categorical) - self.assertNotEqual(mr.AttributeType.Continuous, PyAttributeType.Categorical) - - self.assertEqual(mr.AttributeType.Temporal, mr.AttributeType.Temporal) - self.assertEqual(mr.AttributeType.Temporal, PyAttributeType.Temporal) - self.assertNotEqual(mr.AttributeType.Temporal, mr.AttributeType.Categorical) - self.assertNotEqual(mr.AttributeType.Temporal, PyAttributeType.Categorical) + def test_str(self) -> None: + assert str(mr.AttributeType.Categorical) == "Categorical" + assert str(mr.AttributeType.Continuous) == "Continuous" + assert str(mr.AttributeType.Temporal) == "Temporal" + + def test_eq(self) -> None: + assert mr.AttributeType.Categorical == mr.AttributeType.Categorical + assert mr.AttributeType.Categorical == PyAttributeType.Categorical + assert mr.AttributeType.Categorical != mr.AttributeType.Continuous + assert mr.AttributeType.Categorical != PyAttributeType.Continuous + + assert mr.AttributeType.Continuous == mr.AttributeType.Continuous + assert mr.AttributeType.Continuous == PyAttributeType.Continuous + assert mr.AttributeType.Continuous != mr.AttributeType.Categorical + assert mr.AttributeType.Continuous != PyAttributeType.Categorical + + assert mr.AttributeType.Temporal == mr.AttributeType.Temporal + assert mr.AttributeType.Temporal == PyAttributeType.Temporal + assert mr.AttributeType.Temporal != mr.AttributeType.Categorical + assert mr.AttributeType.Temporal != PyAttributeType.Categorical diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index 9a6e5be6..8615c744 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -1,16 +1,18 @@ from __future__ import annotations -from typing import Any, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional import medmodels.treatment_effect.treatment_effect as tee -from medmodels.medrecord.querying import NodeOperation -from medmodels.medrecord.types import ( - Group, - MedRecordAttribute, - MedRecordAttributeInputList, -) -from medmodels.treatment_effect.matching.algorithms.propensity_score import Model -from medmodels.treatment_effect.matching.matching import MatchingMethod + +if TYPE_CHECKING: + from medmodels.medrecord.querying import NodeOperation + from medmodels.medrecord.types import ( + Group, + MedRecordAttribute, + MedRecordAttributeInputList, + ) + from medmodels.treatment_effect.matching.algorithms.propensity_score import Model + from medmodels.treatment_effect.matching.matching import MatchingMethod class TreatmentEffectBuilder: @@ -218,8 +220,8 @@ def filter_controls(self, operation: NodeOperation) -> TreatmentEffectBuilder: def with_propensity_matching( self, - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, model: Model = "logit", number_of_neighbors: int = 1, hyperparam: Optional[Dict[str, Any]] = None, @@ -244,6 +246,10 @@ def with_propensity_matching( TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder with updated matching configurations. """ + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] self.matching_method = "propensity" self.matching_essential_covariates = essential_covariates self.matching_one_hot_covariates = one_hot_covariates @@ -255,8 +261,8 @@ def with_propensity_matching( def with_nearest_neighbors_matching( self, - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, number_of_neighbors: int = 1, ) -> TreatmentEffectBuilder: """Adjust the treatment effect estimate using nearest neighbors matching. @@ -275,6 +281,10 @@ def with_nearest_neighbors_matching( TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder with updated matching configurations. """ + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] self.matching_method = "nearest_neighbors" self.matching_essential_covariates = essential_covariates self.matching_one_hot_covariates = one_hot_covariates diff --git a/medmodels/treatment_effect/continuous_estimators.py b/medmodels/treatment_effect/continuous_estimators.py index f0310616..09388ccf 100644 --- a/medmodels/treatment_effect/continuous_estimators.py +++ b/medmodels/treatment_effect/continuous_estimators.py @@ -74,7 +74,8 @@ def average_treatment_effect( ] ) if not all(isinstance(i, (int, float)) for i in treated_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) control_outcomes = np.array( [ @@ -91,7 +92,8 @@ def average_treatment_effect( ] ) if not all(isinstance(i, (int, float)) for i in control_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) return treated_outcomes.mean() - control_outcomes.mean() @@ -168,7 +170,8 @@ def cohens_d( ] ) if not all(isinstance(i, (int, float)) for i in treated_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) control_outcomes = np.array( [ @@ -185,7 +188,8 @@ def cohens_d( ] ) if not all(isinstance(i, (int, float)) for i in control_outcomes): - raise ValueError("Outcome variable must be numeric") + msg = "Outcome variable must be numeric" + raise ValueError(msg) min_len = min(len(treated_outcomes), len(control_outcomes)) cf = 1 # correction factor diff --git a/medmodels/treatment_effect/estimate.py b/medmodels/treatment_effect/estimate.py index 09686e5d..e9ddb915 100644 --- a/medmodels/treatment_effect/estimate.py +++ b/medmodels/treatment_effect/estimate.py @@ -2,17 +2,17 @@ from typing import TYPE_CHECKING, Dict, Literal, Set, Tuple -from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import MedRecordAttribute, NodeIndex from medmodels.treatment_effect.continuous_estimators import ( average_treatment_effect, cohens_d, ) -from medmodels.treatment_effect.matching.matching import Matching from medmodels.treatment_effect.matching.neighbors import NeighborsMatching from medmodels.treatment_effect.matching.propensity import PropensityMatching if TYPE_CHECKING: + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.types import MedRecordAttribute, NodeIndex + from medmodels.treatment_effect.matching.matching import Matching from medmodels.treatment_effect.treatment_effect import TreatmentEffect @@ -33,20 +33,29 @@ def _check_medrecord(self, medrecord: MedRecord) -> None: MedRecord (patients, treatments, outcomes). """ if self._treatment_effect._patients_group not in medrecord.groups: - raise ValueError( + msg = ( f"Patient group {self._treatment_effect._patients_group} not found in " f"the MedRecord. Available groups: {medrecord.groups}" ) - if self._treatment_effect._treatments_group not in medrecord.groups: raise ValueError( + msg + ) + if self._treatment_effect._treatments_group not in medrecord.groups: + msg = ( "Treatment group not found in the MedRecord. " f"Available groups: {medrecord.groups}" ) - if self._treatment_effect._outcomes_group not in medrecord.groups: raise ValueError( + msg + ) + if self._treatment_effect._outcomes_group not in medrecord.groups: + msg = ( "Outcome group not found in the MedRecord." f"Available groups: {medrecord.groups}" ) + raise ValueError( + msg + ) def _sort_subjects_in_contingency_table( self, medrecord: MedRecord @@ -127,11 +136,14 @@ def _compute_subject_counts( ) if len(treatment_false) == 0: - raise ValueError("No subjects found in the treatment false group") + msg = "No subjects found in the treatment false group" + raise ValueError(msg) if len(control_true) == 0: - raise ValueError("No subjects found in the control true group") + msg = "No subjects found in the control true group" + raise ValueError(msg) if len(control_false) == 0: - raise ValueError("No subjects found in the control false group") + msg = "No subjects found in the control false group" + raise ValueError(msg) return ( len(treatment_true), @@ -160,15 +172,13 @@ def subjects_contingency_table( self._sort_subjects_in_contingency_table(medrecord=medrecord) ) - subjects = { + return { "treatment_true": treatment_true, "treatment_false": treatment_false, "control_true": control_true, "control_false": control_false, } - return subjects - def subject_counts(self, medrecord: MedRecord) -> Dict[str, int]: """Returns the subject counts for the treatment and control groups in a Dictionary. @@ -193,15 +203,13 @@ def subject_counts(self, medrecord: MedRecord) -> Dict[str, int]: num_control_false, ) = self._compute_subject_counts(medrecord=medrecord) - subject_counts = { + return { "treatment_true": num_treat_true, "treatment_false": num_treat_false, "control_true": num_control_true, "control_false": num_control_false, } - return subject_counts - def relative_risk(self, medrecord: MedRecord) -> float: """Calculates the relative risk (RR) of an event occurring in the treatment group compared to the control group. @@ -371,7 +379,8 @@ def number_needed_to_treat(self, medrecord: MedRecord) -> float: """ ar = self.absolute_risk(medrecord) if ar == 0: - raise ValueError("Absolute risk is zero, cannot calculate NNT.") + msg = "Absolute risk is zero, cannot calculate NNT." + raise ValueError(msg) return 1 / ar def hazard_ratio(self, medrecord: MedRecord) -> float: @@ -404,8 +413,9 @@ def hazard_ratio(self, medrecord: MedRecord) -> float: hazard_control = num_control_true / (num_control_true + num_control_false) if hazard_control == 0: + msg = "Control hazard rate is zero, cannot calculate hazard ratio." raise ValueError( - "Control hazard rate is zero, cannot calculate hazard ratio." + msg ) return hazard_treat / hazard_control diff --git a/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py b/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py index 9d6e9141..3c0dd961 100644 --- a/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py +++ b/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py @@ -35,8 +35,9 @@ def nearest_neighbor( unit. """ if treated_set.shape[0] * number_of_neighbors > control_set.shape[0]: + msg = "The treated set is too large for the given number of neighbors." raise ValueError( - "The treated set is too large for the given number of neighbors." + msg ) if not covariates: covariates = treated_set.columns diff --git a/medmodels/treatment_effect/matching/algorithms/propensity_score.py b/medmodels/treatment_effect/matching/algorithms/propensity_score.py index 92ca7c4b..221c2fad 100644 --- a/medmodels/treatment_effect/matching/algorithms/propensity_score.py +++ b/medmodels/treatment_effect/matching/algorithms/propensity_score.py @@ -4,12 +4,10 @@ import numpy as np import polars as pl -from numpy.typing import NDArray from sklearn.ensemble import RandomForestClassifier from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier -from medmodels.medrecord.types import MedRecordAttributeInputList from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -17,6 +15,10 @@ if TYPE_CHECKING: import sys + from numpy.typing import NDArray + + from medmodels.medrecord.types import MedRecordAttributeInputList + if sys.version_info >= (3, 10): from typing import TypeAlias else: diff --git a/medmodels/treatment_effect/matching/matching.py b/medmodels/treatment_effect/matching/matching.py index a396970f..184e97ee 100644 --- a/medmodels/treatment_effect/matching/matching.py +++ b/medmodels/treatment_effect/matching/matching.py @@ -5,12 +5,12 @@ import polars as pl -from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex - if TYPE_CHECKING: import sys + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + if sys.version_info >= (3, 10): from typing import TypeAlias else: diff --git a/medmodels/treatment_effect/matching/neighbors.py b/medmodels/treatment_effect/matching/neighbors.py index 5fb71e9d..f0d14bd1 100644 --- a/medmodels/treatment_effect/matching/neighbors.py +++ b/medmodels/treatment_effect/matching/neighbors.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import Set +from typing import TYPE_CHECKING, Set -from medmodels import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) from medmodels.treatment_effect.matching.matching import Matching +if TYPE_CHECKING: + from medmodels import MedRecord + from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + class NeighborsMatching(Matching): """Class for the nearest neighbor matching. @@ -25,7 +27,7 @@ def __init__( self, *, number_of_neighbors: int = 1, - ): + ) -> None: """Initializes the nearest neighbors class. Args: @@ -40,8 +42,8 @@ def match_controls( medrecord: MedRecord, control_group: Set[NodeIndex], treated_group: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, ) -> Set[NodeIndex]: """Matches the controls based on the nearest neighbor algorithm. @@ -57,6 +59,10 @@ def match_controls( Returns: Set[NodeIndex]: Node Ids of the matched controls. """ + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] data_treated, data_control = self._preprocess_data( medrecord=medrecord, control_group=control_group, diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index 3ce6cfa6..4625ecbc 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -1,12 +1,10 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, Optional, Set import numpy as np import polars as pl -from medmodels import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -16,6 +14,10 @@ ) from medmodels.treatment_effect.matching.matching import Matching +if TYPE_CHECKING: + from medmodels import MedRecord + from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex + class PropensityMatching(Matching): """Class for the propensity score matching. @@ -37,7 +39,7 @@ def __init__( model: Model = "logit", number_of_neighbors: int = 1, hyperparam: Optional[Dict[str, Any]] = None, - ): + ) -> None: """Initializes the propensity score class. Args: @@ -60,8 +62,8 @@ def match_controls( medrecord: MedRecord, control_group: Set[NodeIndex], treated_group: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: MedRecordAttributeInputList = None, + one_hot_covariates: MedRecordAttributeInputList = None, ) -> Set[NodeIndex]: """Matches the controls based on propensity score matching. @@ -78,6 +80,10 @@ def match_controls( Set[NodeIndex]: Node Ids of the matched controls. """ # Preprocess the data + if one_hot_covariates is None: + one_hot_covariates = ["gender"] + if essential_covariates is None: + essential_covariates = ["gender", "age"] data_treated, data_control = self._preprocess_data( medrecord=medrecord, treated_group=treated_group, diff --git a/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py b/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py index ba2e6dc7..9243a08a 100644 --- a/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py +++ b/medmodels/treatment_effect/matching/tests/test_classic_distance_models.py @@ -2,6 +2,7 @@ import numpy as np import polars as pl +import pytest from medmodels.treatment_effect.matching.algorithms import ( classic_distance_models as cdm, @@ -9,7 +10,7 @@ class TestClassicDistanceModels(unittest.TestCase): - def test_nearest_neighbor(self): + def test_nearest_neighbor(self) -> None: ########################################### # 1D example c_set = pl.DataFrame({"a": [1, 5, 1, 3]}) @@ -17,7 +18,7 @@ def test_nearest_neighbor(self): expected_result = pl.DataFrame({"a": [1.0, 5.0]}) result = cdm.nearest_neighbor(t_set, c_set) - self.assertTrue(result.equals(expected_result)) + assert result.equals(expected_result) ########################################### # 3D example with covariates @@ -29,7 +30,7 @@ def test_nearest_neighbor(self): expected_result = pl.DataFrame([[1.0, 3.0, 5.0]], schema=cols, orient="row") result = cdm.nearest_neighbor(t_set, c_set, covariates=covs) - self.assertTrue(result.equals(expected_result)) + assert result.equals(expected_result) # 2 nearest neighbors expected_abs_2nn = pl.DataFrame( @@ -38,21 +39,19 @@ def test_nearest_neighbor(self): result_abs_2nn = cdm.nearest_neighbor( t_set, c_set, covariates=covs, number_of_neighbors=2 ) - self.assertTrue(result_abs_2nn.equals(expected_abs_2nn)) + assert result_abs_2nn.equals(expected_abs_2nn) - def test_nearest_neighbor_value_error(self): - # Test case for checking the ValueError when all control units have been matched + def test_nearest_neighbor_value_error(self) -> None: + """Checking the ValueError when all control units have been matched.""" c_set = pl.DataFrame({"a": [1, 2]}) t_set = pl.DataFrame({"a": [1, 2, 3]}) - with self.assertRaises(ValueError) as context: + with pytest.raises( + ValueError, + match="The treated set is too large for the given number of neighbors.", + ): cdm.nearest_neighbor(t_set, c_set, number_of_neighbors=2) - self.assertEqual( - str(context.exception), - "The treated set is too large for the given number of neighbors.", - ) - if __name__ == "__main__": run_test = unittest.TestLoader().loadTestsFromTestCase(TestClassicDistanceModels) diff --git a/medmodels/treatment_effect/matching/tests/test_metrics.py b/medmodels/treatment_effect/matching/tests/test_metrics.py index 80c16f34..81074f0f 100644 --- a/medmodels/treatment_effect/matching/tests/test_metrics.py +++ b/medmodels/treatment_effect/matching/tests/test_metrics.py @@ -6,21 +6,17 @@ class TestMetrics(unittest.TestCase): - def test_absolute_metric(self): - self.assertEqual(metrics.absolute_metric(np.array([-2]), np.array([-1])), 1) - self.assertEqual( - metrics.absolute_metric(np.array([2, -1]), np.array([1, 3])), 5 - ) + def test_absolute_metric(self) -> None: + assert metrics.absolute_metric(np.array([-2]), np.array([-1])) == 1 + assert metrics.absolute_metric(np.array([2, -1]), np.array([1, 3])) == 5 - def test_exact_metric(self): - self.assertEqual(metrics.exact_metric(np.array([-2]), np.array([-2])), 0) - self.assertEqual(metrics.exact_metric(np.array([-2]), np.array([-1])), np.inf) - self.assertEqual(metrics.exact_metric(np.array([2, -1]), np.array([2, -1])), 0) - self.assertEqual( - metrics.exact_metric(np.array([2, -1]), np.array([2, 1])), np.inf - ) + def test_exact_metric(self) -> None: + assert metrics.exact_metric(np.array([-2]), np.array([-2])) == 0 + assert metrics.exact_metric(np.array([-2]), np.array([-1])) == np.inf + assert metrics.exact_metric(np.array([2, -1]), np.array([2, -1])) == 0 + assert metrics.exact_metric(np.array([2, -1]), np.array([2, 1])) == np.inf - def test_mahalanobis_metric(self): + def test_mahalanobis_metric(self) -> None: data = np.array( [[64, 66, 68, 69, 73], [580, 570, 590, 660, 600], [29, 33, 37, 46, 55]] ) diff --git a/medmodels/treatment_effect/matching/tests/test_propensity_score.py b/medmodels/treatment_effect/matching/tests/test_propensity_score.py index e505d6ab..af6a6400 100644 --- a/medmodels/treatment_effect/matching/tests/test_propensity_score.py +++ b/medmodels/treatment_effect/matching/tests/test_propensity_score.py @@ -8,7 +8,7 @@ class TestPropensityScore(unittest.TestCase): - def test_calculate_propensity(self): + def test_calculate_propensity(self) -> None: x, y = load_iris(return_X_y=True) # Set random state by each propensity estimator: @@ -48,7 +48,7 @@ def test_calculate_propensity(self): self.assertAlmostEqual(result_1[0], 0, places=2) self.assertAlmostEqual(result_2[0], 0, places=2) - def test_run_propensity_score(self): + def test_run_propensity_score(self) -> None: # Set random state by each propensity estimator: hyperparam = {"random_state": 1} hyperparam_logit = {"random_state": 1, "max_iter": 200} @@ -63,21 +63,21 @@ def test_run_propensity_score(self): result_logit = ps.run_propensity_score( treated_set, control_set, hyperparam=hyperparam_logit ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # dec_tree metric expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( treated_set, control_set, model="dec_tree", hyperparam=hyperparam ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # forest model expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( treated_set, control_set, model="forest", hyperparam=hyperparam ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) ########################################### # 3D example with covariates @@ -92,7 +92,7 @@ def test_run_propensity_score(self): result_logit = ps.run_propensity_score( treated_set, control_set, covariates=covs, hyperparam=hyperparam_logit ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # dec_tree model expected_logit = pl.DataFrame({"a": [1.0], "b": [3.0], "c": [5.0]}) @@ -103,7 +103,7 @@ def test_run_propensity_score(self): covariates=covs, hyperparam=hyperparam, ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # forest model expected_logit = pl.DataFrame({"a": [1.0], "b": [3.0], "c": [5.0]}) @@ -114,7 +114,7 @@ def test_run_propensity_score(self): covariates=covs, hyperparam=hyperparam, ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) # using 2 nearest neighbors expected_logit = pl.DataFrame( @@ -125,7 +125,7 @@ def test_run_propensity_score(self): control_set, number_of_neighbors=2, ) - self.assertTrue(result_logit.equals(expected_logit)) + assert result_logit.equals(expected_logit) if __name__ == "__main__": diff --git a/medmodels/treatment_effect/report.py b/medmodels/treatment_effect/report.py index 883c6d23..c999f834 100644 --- a/medmodels/treatment_effect/report.py +++ b/medmodels/treatment_effect/report.py @@ -2,10 +2,9 @@ from typing import TYPE_CHECKING, Literal, TypedDict -from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import MedRecordAttribute - if TYPE_CHECKING: + from medmodels.medrecord.medrecord import MedRecord + from medmodels.medrecord.types import MedRecordAttribute from medmodels.treatment_effect.treatment_effect import TreatmentEffect diff --git a/medmodels/treatment_effect/temporal_analysis.py b/medmodels/treatment_effect/temporal_analysis.py index f87d79d8..f8787e55 100644 --- a/medmodels/treatment_effect/temporal_analysis.py +++ b/medmodels/treatment_effect/temporal_analysis.py @@ -77,7 +77,8 @@ def find_reference_edge( edge_values = medrecord.edge[edges].values() if not all(time_attribute in edge_attribute for edge_attribute in edge_values): - raise ValueError("Time attribute not found in the edge attributes") + msg = "Time attribute not found in the edge attributes" + raise ValueError(msg) for edge in edges: edge_time = pd.to_datetime(str(medrecord.edge[edge][time_attribute])) @@ -89,7 +90,8 @@ def find_reference_edge( reference_time = edge_time if reference_edge is None: - raise ValueError(f"No edge found for node {node_index} in this MedRecord") + msg = f"No edge found for node {node_index} in this MedRecord" + raise ValueError(msg) return reference_edge @@ -171,7 +173,8 @@ def find_node_in_time_window( for edge in edges: edge_attributes = medrecord.edge[edge] if time_attribute not in edge_attributes: - raise ValueError("Time attribute not found in the edge attributes") + msg = "Time attribute not found in the edge attributes" + raise ValueError(msg) event_time = pd.to_datetime(str(edge_attributes[time_attribute])) time_difference = event_time - reference_time diff --git a/medmodels/treatment_effect/tests/test_continuous_estimators.py b/medmodels/treatment_effect/tests/test_continuous_estimators.py index 70db8809..1a067a4f 100644 --- a/medmodels/treatment_effect/tests/test_continuous_estimators.py +++ b/medmodels/treatment_effect/tests/test_continuous_estimators.py @@ -1,9 +1,10 @@ """Tests for the TreatmentEffect class in the treatment_effect module.""" import unittest -from typing import List +from typing import List, Optional import pandas as pd +import pytest from medmodels import MedRecord from medmodels.medrecord.types import NodeIndex @@ -37,8 +38,7 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: } ) - patients = patients.loc[patients["index"].isin(patient_list)] - return patients + return patients.loc[patients["index"].isin(patient_list)] def create_diagnoses() -> pd.DataFrame: @@ -47,13 +47,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - diagnoses = pd.DataFrame( + return pd.DataFrame( { "index": ["D1"], "name": ["Stroke"], } ) - return diagnoses def create_prescriptions() -> pd.DataFrame: @@ -62,13 +61,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - prescriptions = pd.DataFrame( + return pd.DataFrame( { "index": ["M1", "M2"], "name": ["Rivaroxaban", "Warfarin"], } ) - return prescriptions def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -108,8 +106,7 @@ def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -162,28 +159,19 @@ def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_medrecord( - patient_list: List[NodeIndex] = [ - "P1", - "P2", - "P3", - "P4", - "P5", - "P6", - "P7", - "P8", - "P9", - ], + patient_list: Optional[List[NodeIndex]] = None, ) -> MedRecord: """Creates a MedRecord object. Returns: MedRecord: A MedRecord object. """ + if patient_list is None: + patient_list = ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"] patients = create_patients(patient_list=patient_list) diagnoses = create_diagnoses() prescriptions = create_prescriptions() @@ -213,12 +201,12 @@ def create_medrecord( class TestContinuousEstimators(unittest.TestCase): """Class to test the continuous estimators.""" - def setUp(self): + def setUp(self) -> None: self.medrecord = create_medrecord() self.outcome_group = "Stroke" self.time_attribute = "time" - def test_average_treatment_effect(self): + def test_average_treatment_effect(self) -> None: ate_result = average_treatment_effect( self.medrecord, treatment_true_set=set({"P2", "P3"}), @@ -241,8 +229,8 @@ def test_average_treatment_effect(self): ) self.assertAlmostEqual(-0.15, ate_result) - def test_invalid_treatment_effect(self): - with self.assertRaisesRegex(ValueError, "Outcome variable must be numeric"): + def test_invalid_treatment_effect(self) -> None: + with pytest.raises(ValueError, match="Outcome variable must be numeric"): average_treatment_effect( self.medrecord, treatment_true_set=set({"P2", "P3"}), @@ -253,7 +241,7 @@ def test_invalid_treatment_effect(self): time_attribute=self.time_attribute, ) - def test_cohens_d(self): + def test_cohens_d(self) -> None: cohens_d_result = cohens_d( self.medrecord, treatment_true_set=set({"P2", "P3"}), @@ -288,8 +276,8 @@ def test_cohens_d(self): ) self.assertAlmostEqual(0, cohens_d_corrected) - def test_invalid_cohens_D(self): - with self.assertRaisesRegex(ValueError, "Outcome variable must be numeric"): + def test_invalid_cohens_D(self) -> None: + with pytest.raises(ValueError, match="Outcome variable must be numeric"): cohens_d( self.medrecord, treatment_true_set=set({"P2", "P3"}), diff --git a/medmodels/treatment_effect/tests/test_temporal_analysis.py b/medmodels/treatment_effect/tests/test_temporal_analysis.py index 9a118754..ac75f336 100644 --- a/medmodels/treatment_effect/tests/test_temporal_analysis.py +++ b/medmodels/treatment_effect/tests/test_temporal_analysis.py @@ -1,7 +1,8 @@ import unittest -from typing import List +from typing import List, Optional import pandas as pd +import pytest from medmodels.medrecord.medrecord import MedRecord from medmodels.medrecord.types import NodeIndex @@ -29,8 +30,7 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: } ) - patients = patients.loc[patients["index"].isin(patient_list)] - return patients + return patients.loc[patients["index"].isin(patient_list)] def create_diagnoses() -> pd.DataFrame: @@ -39,13 +39,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - diagnoses = pd.DataFrame( + return pd.DataFrame( { "index": ["D1"], "name": ["Stroke"], } ) - return diagnoses def create_prescriptions() -> pd.DataFrame: @@ -54,13 +53,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - prescriptions = pd.DataFrame( + return pd.DataFrame( { "index": ["M1", "M2"], "name": ["Rivaroxaban", "Warfarin"], } ) - return prescriptions def create_edges(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -94,22 +92,19 @@ def create_edges(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_medrecord( - patient_list: List[NodeIndex] = [ - "P1", - "P2", - "P3", - ], + patient_list: Optional[List[NodeIndex]] = None, ) -> MedRecord: """Creates a MedRecord object. Returns: MedRecord: A MedRecord object. """ + if patient_list is None: + patient_list = ["P1", "P2", "P3"] patients = create_patients(patient_list=patient_list) diagnoses = create_diagnoses() prescriptions = create_prescriptions() @@ -137,17 +132,17 @@ def create_medrecord( class TestTreatmentEffect(unittest.TestCase): """""" - def setUp(self): + def setUp(self) -> None: self.medrecord = create_medrecord() - def test_find_reference_time(self): + def test_find_reference_time(self) -> None: edge = find_reference_edge( self.medrecord, node_index="P1", reference="last", connected_group="Rivaroxaban", ) - self.assertEqual(0, edge) + assert edge == 0 # adding medication time self.medrecord.add_edge( @@ -160,7 +155,7 @@ def test_find_reference_time(self): reference="last", connected_group="Rivaroxaban", ) - self.assertEqual(5, edge) + assert edge == 5 edge = find_reference_edge( self.medrecord, @@ -168,12 +163,10 @@ def test_find_reference_time(self): reference="first", connected_group="Rivaroxaban", ) - self.assertEqual(0, edge) + assert edge == 0 - def test_invalid_find_reference_time(self): - with self.assertRaisesRegex( - ValueError, "Time attribute not found in the edge attributes" - ): + def test_invalid_find_reference_time(self) -> None: + with pytest.raises(ValueError, match="Time attribute not found in the edge attributes"): find_reference_edge( self.medrecord, node_index="P1", @@ -183,9 +176,7 @@ def test_invalid_find_reference_time(self): ) node_index = "P2" - with self.assertRaisesRegex( - ValueError, f"No edge found for node {node_index} in this MedRecord" - ): + with pytest.raises(ValueError, match=f"No edge found for node {node_index} in this MedRecord"): find_reference_edge( self.medrecord, node_index=node_index, @@ -194,7 +185,7 @@ def test_invalid_find_reference_time(self): time_attribute="time", ) - def test_node_in_time_window(self): + def test_node_in_time_window(self) -> None: # check if patient has outcome a year after treatment node_found = find_node_in_time_window( self.medrecord, @@ -206,7 +197,7 @@ def test_node_in_time_window(self): reference="last", time_attribute="time", ) - self.assertTrue(node_found) + assert node_found # check if patient has outcome 30 days after treatment node_found2 = find_node_in_time_window( @@ -219,12 +210,10 @@ def test_node_in_time_window(self): reference="last", time_attribute="time", ) - self.assertFalse(node_found2) + assert not node_found2 - def test_invalid_node_in_time_window(self): - with self.assertRaisesRegex( - ValueError, "Time attribute not found in the edge attributes" - ): + def test_invalid_node_in_time_window(self) -> None: + with pytest.raises(ValueError, match="Time attribute not found in the edge attributes"): find_node_in_time_window( self.medrecord, subject_index="P3", diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index f217cef5..014ebd97 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -1,9 +1,10 @@ """Tests for the TreatmentEffect class in the treatment_effect module.""" import unittest -from typing import List +from typing import List, Optional import pandas as pd +import pytest from medmodels import MedRecord from medmodels.medrecord import edge, node @@ -35,8 +36,7 @@ def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: } ) - patients = patients.loc[patients["index"].isin(patient_list)] - return patients + return patients.loc[patients["index"].isin(patient_list)] def create_diagnoses() -> pd.DataFrame: @@ -45,13 +45,12 @@ def create_diagnoses() -> pd.DataFrame: Returns: pd.DataFrame: A diagnoses dataframe. """ - diagnoses = pd.DataFrame( + return pd.DataFrame( { "index": ["D1"], "name": ["Stroke"], } ) - return diagnoses def create_prescriptions() -> pd.DataFrame: @@ -60,13 +59,12 @@ def create_prescriptions() -> pd.DataFrame: Returns: pd.DataFrame: A prescriptions dataframe. """ - prescriptions = pd.DataFrame( + return pd.DataFrame( { "index": ["M1", "M2"], "name": ["Rivaroxaban", "Warfarin"], } ) - return prescriptions def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -106,8 +104,7 @@ def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: @@ -152,28 +149,19 @@ def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame: ], } ) - edges = edges.loc[edges["target"].isin(patient_list)] - return edges + return edges.loc[edges["target"].isin(patient_list)] def create_medrecord( - patient_list: List[NodeIndex] = [ - "P1", - "P2", - "P3", - "P4", - "P5", - "P6", - "P7", - "P8", - "P9", - ], + patient_list: Optional[List[NodeIndex]] = None, ) -> MedRecord: """Creates a MedRecord object. Returns: MedRecord: A MedRecord object. """ + if patient_list is None: + patient_list = ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"] patients = create_patients(patient_list=patient_list) diagnoses = create_diagnoses() prescriptions = create_prescriptions() @@ -204,79 +192,34 @@ def assert_treatment_effects_equal( test_case: unittest.TestCase, treatment_effect1: TreatmentEffect, treatment_effect2: TreatmentEffect, -): - test_case.assertEqual( - treatment_effect1._treatments_group, treatment_effect2._treatments_group - ) - test_case.assertEqual( - treatment_effect1._outcomes_group, treatment_effect2._outcomes_group - ) - test_case.assertEqual( - treatment_effect1._patients_group, treatment_effect2._patients_group - ) - test_case.assertEqual( - treatment_effect1._time_attribute, treatment_effect2._time_attribute - ) - test_case.assertEqual( - treatment_effect1._washout_period_days, treatment_effect2._washout_period_days - ) - test_case.assertEqual( - treatment_effect1._washout_period_reference, - treatment_effect2._washout_period_reference, - ) - test_case.assertEqual( - treatment_effect1._grace_period_days, treatment_effect2._grace_period_days - ) - test_case.assertEqual( - treatment_effect1._grace_period_reference, - treatment_effect2._grace_period_reference, - ) - test_case.assertEqual( - treatment_effect1._follow_up_period_days, - treatment_effect2._follow_up_period_days, - ) - test_case.assertEqual( - treatment_effect1._follow_up_period_reference, - treatment_effect2._follow_up_period_reference, - ) - test_case.assertEqual( - treatment_effect1._outcome_before_treatment_days, - treatment_effect2._outcome_before_treatment_days, - ) - test_case.assertEqual( - treatment_effect1._filter_controls_operation, - treatment_effect2._filter_controls_operation, - ) - test_case.assertEqual( - treatment_effect1._matching_method, treatment_effect2._matching_method - ) - test_case.assertEqual( - treatment_effect1._matching_essential_covariates, - treatment_effect2._matching_essential_covariates, - ) - test_case.assertEqual( - treatment_effect1._matching_one_hot_covariates, - treatment_effect2._matching_one_hot_covariates, - ) - test_case.assertEqual( - treatment_effect1._matching_model, treatment_effect2._matching_model - ) - test_case.assertEqual( - treatment_effect1._matching_number_of_neighbors, - treatment_effect2._matching_number_of_neighbors, - ) - test_case.assertEqual( - treatment_effect1._matching_hyperparam, treatment_effect2._matching_hyperparam - ) +) -> None: + assert treatment_effect1._treatments_group == treatment_effect2._treatments_group + assert treatment_effect1._outcomes_group == treatment_effect2._outcomes_group + assert treatment_effect1._patients_group == treatment_effect2._patients_group + assert treatment_effect1._time_attribute == treatment_effect2._time_attribute + assert treatment_effect1._washout_period_days == treatment_effect2._washout_period_days + assert treatment_effect1._washout_period_reference == treatment_effect2._washout_period_reference + assert treatment_effect1._grace_period_days == treatment_effect2._grace_period_days + assert treatment_effect1._grace_period_reference == treatment_effect2._grace_period_reference + assert treatment_effect1._follow_up_period_days == treatment_effect2._follow_up_period_days + assert treatment_effect1._follow_up_period_reference == treatment_effect2._follow_up_period_reference + assert treatment_effect1._outcome_before_treatment_days == treatment_effect2._outcome_before_treatment_days + assert treatment_effect1._filter_controls_operation == treatment_effect2._filter_controls_operation + assert treatment_effect1._matching_method == treatment_effect2._matching_method + assert treatment_effect1._matching_essential_covariates == treatment_effect2._matching_essential_covariates + assert treatment_effect1._matching_one_hot_covariates == treatment_effect2._matching_one_hot_covariates + assert treatment_effect1._matching_model == treatment_effect2._matching_model + assert treatment_effect1._matching_number_of_neighbors == treatment_effect2._matching_number_of_neighbors + assert treatment_effect1._matching_hyperparam == treatment_effect2._matching_hyperparam class TestTreatmentEffect(unittest.TestCase): """Class to test the TreatmentEffect class in the treatment_effect module.""" - def setUp(self): + def setUp(self) -> None: self.medrecord = create_medrecord() - def test_init(self): + def test_init(self) -> None: # Initialize TreatmentEffect object tee = TreatmentEffect( treatment="Rivaroxaban", @@ -292,7 +235,7 @@ def test_init(self): assert_treatment_effects_equal(self, tee, tee_builder) - def test_default_properties(self): + def test_default_properties(self) -> None: tee = TreatmentEffect( treatment="Rivaroxaban", outcome="Stroke", @@ -312,7 +255,7 @@ def test_default_properties(self): assert_treatment_effects_equal(self, tee, tee_builder) - def test_check_medrecord(self): + def test_check_medrecord(self) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -320,9 +263,7 @@ def test_check_medrecord(self): .build() ) - with self.assertRaisesRegex( - ValueError, "Treatment group not found in the MedRecord" - ): + with pytest.raises(ValueError, match="Treatment group not found in the MedRecord"): tee.estimate._check_medrecord(medrecord=self.medrecord) tee2 = ( @@ -332,9 +273,7 @@ def test_check_medrecord(self): .build() ) - with self.assertRaisesRegex( - ValueError, "Outcome group not found in the MedRecord" - ): + with pytest.raises(ValueError, match="Outcome group not found in the MedRecord"): tee2.estimate._check_medrecord(medrecord=self.medrecord) patient_group = "subjects" @@ -346,12 +285,10 @@ def test_check_medrecord(self): .build() ) - with self.assertRaisesRegex( - ValueError, f"Patient group {patient_group} not found in the MedRecord" - ): + with pytest.raises(ValueError, match=f"Patient group {patient_group} not found in the MedRecord"): tee3.estimate._check_medrecord(medrecord=self.medrecord) - def test_find_treated_patients(self): + def test_find_treated_patients(self) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -360,18 +297,16 @@ def test_find_treated_patients(self): ) treated_group = tee._find_treated_patients(self.medrecord) - self.assertEqual(treated_group, set({"P2", "P3", "P6"})) + assert treated_group == set({"P2", "P3", "P6"}) # no treatment_group patients = set(self.medrecord.nodes_in_group("patients")) medrecord2 = create_medrecord(list(patients - treated_group)) - with self.assertRaisesRegex( - ValueError, "No patients found for the treatment groups in this MedRecord." - ): + with pytest.raises(ValueError, match="No patients found for the treatment groups in this MedRecord."): tee.estimate.subject_counts(medrecord=medrecord2) - def test_find_groups(self): + def test_find_groups(self) -> None: tee = ( TreatmentEffect.builder() .with_outcome("Stroke") @@ -382,12 +317,12 @@ def test_find_groups(self): treatment_true, treatment_false, control_true, control_false = tee._find_groups( self.medrecord ) - self.assertEqual(treatment_true, set({"P2", "P3"})) - self.assertEqual(treatment_false, set({"P6"})) - self.assertEqual(control_true, set({"P1", "P4", "P7"})) - self.assertEqual(control_false, set({"P5", "P8", "P9"})) + assert treatment_true == set({"P2", "P3"}) + assert treatment_false == set({"P6"}) + assert control_true == set({"P1", "P4", "P7"}) + assert control_false == set({"P5", "P8", "P9"}) - def test_compute_subject_counts(self): + def test_compute_subject_counts(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -396,7 +331,7 @@ def test_compute_subject_counts(self): ) counts = tee.estimate._compute_subject_counts(self.medrecord) - self.assertEqual(counts, (2, 1, 3, 3)) + assert counts == (2, 1, 3, 3) # test value errors if no subjects are found treatment_true, treatment_false, control_true, control_false = tee._find_groups( @@ -407,24 +342,18 @@ def test_compute_subject_counts(self): ) medrecord2 = create_medrecord(patient_list=list(all_patients - control_false)) - with self.assertRaisesRegex( - ValueError, "No subjects found in the control false group" - ): + with pytest.raises(ValueError, match="No subjects found in the control false group"): tee.estimate.subject_counts(medrecord=medrecord2) medrecord3 = create_medrecord(patient_list=list(all_patients - treatment_false)) - with self.assertRaisesRegex( - ValueError, "No subjects found in the treatment false group" - ): + with pytest.raises(ValueError, match="No subjects found in the treatment false group"): tee.estimate.subject_counts(medrecord=medrecord3) medrecord4 = create_medrecord(patient_list=list(all_patients - control_true)) - with self.assertRaisesRegex( - ValueError, "No subjects found in the control true group" - ): + with pytest.raises(ValueError, match="No subjects found in the control true group"): tee.estimate.subject_counts(medrecord=medrecord4) - def test_subject_counts(self): + def test_subject_counts(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -441,7 +370,7 @@ def test_subject_counts(self): self.assertDictEqual(counts_tee, counts_test) - def test_subjects_contigency_table(self): + def test_subjects_contigency_table(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -458,7 +387,7 @@ def test_subjects_contigency_table(self): subjects_tee = tee.estimate.subjects_contingency_table(self.medrecord) self.assertDictEqual(subjects_test, subjects_tee) - def test_follow_up_period(self): + def test_follow_up_period(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -467,7 +396,7 @@ def test_follow_up_period(self): .build() ) - self.assertEqual(tee._follow_up_period_days, 30) + assert tee._follow_up_period_days == 30 counts_test = { "treatment_true": 1, @@ -479,7 +408,7 @@ def test_follow_up_period(self): self.assertDictEqual(counts_tee, counts_test) - def test_grace_period(self): + def test_grace_period(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -488,7 +417,7 @@ def test_grace_period(self): .build() ) - self.assertEqual(tee._grace_period_days, 10) + assert tee._grace_period_days == 10 counts_test = { "treatment_true": 1, @@ -500,7 +429,7 @@ def test_grace_period(self): self.assertDictEqual(counts_tee, counts_test) - def test_washout_period(self): + def test_washout_period(self) -> None: washout_dict = {"Warfarin": 30} tee = ( @@ -518,8 +447,8 @@ def test_washout_period(self): self.medrecord, treated_group ) - self.assertEqual(treated_group, set({"P3", "P6"})) - self.assertEqual(washout_nodes, set({"P2"})) + assert treated_group == set({"P3", "P6"}) + assert washout_nodes == set({"P2"}) # smaller washout period washout_dict2 = {"Warfarin": 10} @@ -539,10 +468,10 @@ def test_washout_period(self): self.medrecord, treated_group ) - self.assertEqual(treated_group, set({"P2", "P3", "P6"})) - self.assertEqual(washout_nodes, set({})) + assert treated_group == set({"P2", "P3", "P6"}) + assert washout_nodes == set({}) - def test_outcome_before_treatment(self): + def test_outcome_before_treatment(self) -> None: # case 1 find outcomes for default tee tee = ( TreatmentEffect.builder() @@ -554,9 +483,9 @@ def test_outcome_before_treatment(self): treated_group, treatment_true, outcome_before_treatment_nodes = ( tee._find_outcomes(self.medrecord, treated_group) ) - self.assertEqual(treated_group, set({"P2", "P3", "P6"})) - self.assertEqual(treatment_true, set({"P2", "P3"})) - self.assertEqual(outcome_before_treatment_nodes, set()) + assert treated_group == set({"P2", "P3", "P6"}) + assert treatment_true == set({"P2", "P3"}) + assert outcome_before_treatment_nodes == set() # case 2 set exclusion time for outcome before treatment tee2 = ( @@ -567,15 +496,15 @@ def test_outcome_before_treatment(self): .build() ) - self.assertEqual(tee2._outcome_before_treatment_days, 30) + assert tee2._outcome_before_treatment_days == 30 treated_group = tee2._find_treated_patients(self.medrecord) treated_group, treatment_true, outcome_before_treatment_nodes = ( tee2._find_outcomes(self.medrecord, treated_group) ) - self.assertEqual(treated_group, set({"P2", "P6"})) - self.assertEqual(treatment_true, set({"P2"})) - self.assertEqual(outcome_before_treatment_nodes, set({"P3"})) + assert treated_group == set({"P2", "P6"}) + assert treatment_true == set({"P2"}) + assert outcome_before_treatment_nodes == set({"P3"}) # case 3 no outcome @@ -589,12 +518,10 @@ def test_outcome_before_treatment(self): .build() ) - with self.assertRaisesRegex( - ValueError, "No outcomes found in the MedRecord for group " - ): + with pytest.raises(ValueError, match="No outcomes found in the MedRecord for group "): tee3._find_outcomes(medrecord=self.medrecord, treated_group=treated_group) - def test_filter_controls(self): + def test_filter_controls(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -635,7 +562,7 @@ def test_filter_controls(self): self.assertDictEqual(counts_tee2, counts_test2) - def test_nearest_neighbors(self): + def test_nearest_neighbors(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -648,11 +575,11 @@ def test_nearest_neighbors(self): # multiple patients are equally similar to the treatment group # these are exact macthes and should always be included - self.assertIn("P4", subjects["control_true"]) - self.assertIn("P5", subjects["control_false"]) - self.assertIn("P8", subjects["control_false"]) + assert "P4" in subjects["control_true"] + assert "P5" in subjects["control_false"] + assert "P8" in subjects["control_false"] - def test_propensity_matching(self): + def test_propensity_matching(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -663,11 +590,11 @@ def test_propensity_matching(self): subjects = tee.estimate.subjects_contingency_table(self.medrecord) - self.assertIn("P4", subjects["control_true"]) - self.assertIn("P5", subjects["control_false"]) - self.assertIn("P1", subjects["control_true"]) + assert "P4" in subjects["control_true"] + assert "P5" in subjects["control_false"] + assert "P1" in subjects["control_true"] - def test_find_controls(self): + def test_find_controls(self) -> None: tee = ( TreatmentEffect.builder() .with_treatment("Rivaroxaban") @@ -683,12 +610,10 @@ def test_find_controls(self): control_group=patients - treated_group, treated_group=patients.intersection(treated_group), ) - self.assertEqual(control_true, {"P1", "P4", "P7"}) - self.assertEqual(control_false, {"P5", "P8", "P9"}) + assert control_true == {"P1", "P4", "P7"} + assert control_false == {"P5", "P8", "P9"} - with self.assertRaisesRegex( - ValueError, "No patients found for control groups in this MedRecord." - ): + with pytest.raises(ValueError, match="No patients found for control groups in this MedRecord."): tee._find_controls( self.medrecord, control_group=patients - treated_group, @@ -705,16 +630,14 @@ def test_find_controls(self): self.medrecord.add_group("Headache") - with self.assertRaisesRegex( - ValueError, "No outcomes found in the MedRecord for group." - ): + with pytest.raises(ValueError, match="No outcomes found in the MedRecord for group."): tee2._find_controls( self.medrecord, control_group=patients - treated_group, treated_group=patients.intersection(treated_group), ) - def test_metrics(self): + def test_metrics(self) -> None: """Test the metrics of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() @@ -731,7 +654,7 @@ def test_metrics(self): self.assertAlmostEqual(tee.estimate.hazard_ratio(self.medrecord), 4 / 3) self.assertAlmostEqual(tee.estimate.number_needed_to_treat(self.medrecord), 6) - def test_full_report(self): + def test_full_report(self) -> None: """Test the full reporting of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() @@ -755,7 +678,7 @@ def test_full_report(self): } self.assertDictEqual(report_test, full_report) - def test_continuous_estimators_report(self): + def test_continuous_estimators_report(self) -> None: """Test the continuous report of the TreatmentEffect class.""" tee = ( TreatmentEffect.builder() diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index cfe96ab9..2fb823e4 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -11,24 +11,26 @@ from __future__ import annotations import logging -from typing import Any, Dict, Literal, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set, Tuple -from medmodels import MedRecord from medmodels.medrecord import node -from medmodels.medrecord.querying import NodeOperation -from medmodels.medrecord.types import ( - Group, - MedRecordAttribute, - MedRecordAttributeInputList, - NodeIndex, -) from medmodels.treatment_effect.builder import TreatmentEffectBuilder from medmodels.treatment_effect.estimate import Estimate -from medmodels.treatment_effect.matching.algorithms.propensity_score import Model -from medmodels.treatment_effect.matching.matching import MatchingMethod from medmodels.treatment_effect.report import Report from medmodels.treatment_effect.temporal_analysis import find_node_in_time_window +if TYPE_CHECKING: + from medmodels import MedRecord + from medmodels.medrecord.querying import NodeOperation + from medmodels.medrecord.types import ( + Group, + MedRecordAttribute, + MedRecordAttributeInputList, + NodeIndex, + ) + from medmodels.treatment_effect.matching.algorithms.propensity_score import Model + from medmodels.treatment_effect.matching.matching import MatchingMethod + class TreatmentEffect: """This class facilitates the analysis of treatment effects over time and across different patient groups.""" @@ -85,7 +87,7 @@ def _set_configuration( outcome: Group, patients_group: Group = "patients", time_attribute: MedRecordAttribute = "time", - washout_period_days: Dict[str, int] = dict(), + washout_period_days: Optional[Dict[str, int]] = None, washout_period_reference: Literal["first", "last"] = "first", grace_period_days: int = 0, grace_period_reference: Literal["first", "last"] = "last", @@ -94,8 +96,8 @@ def _set_configuration( outcome_before_treatment_days: Optional[int] = None, filter_controls_operation: Optional[NodeOperation] = None, matching_method: Optional[MatchingMethod] = None, - matching_essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - matching_one_hot_covariates: MedRecordAttributeInputList = ["gender"], + matching_essential_covariates: MedRecordAttributeInputList = None, + matching_one_hot_covariates: MedRecordAttributeInputList = None, matching_model: Model = "logit", matching_number_of_neighbors: int = 1, matching_hyperparam: Optional[Dict[str, Any]] = None, @@ -145,6 +147,12 @@ def _set_configuration( matching_hyperparam (Optional[Dict[str, Any]], optional): The hyperparameters for the matching model. Defaults to None. """ + if matching_one_hot_covariates is None: + matching_one_hot_covariates = ["gender"] + if washout_period_days is None: + washout_period_days = {} + if matching_essential_covariates is None: + matching_essential_covariates = ["gender", "age"] treatment_effect._patients_group = patients_group treatment_effect._time_attribute = time_attribute @@ -238,8 +246,9 @@ def _find_treated_patients(self, medrecord: MedRecord) -> Set[NodeIndex]: ) ) if not treated_group: + msg = "No patients found for the treatment groups in this MedRecord." raise ValueError( - "No patients found for the treatment groups in this MedRecord." + msg ) return treated_group @@ -275,8 +284,9 @@ def _find_outcomes( # Find nodes with the outcomes outcomes = medrecord.nodes_in_group(self._outcomes_group) if not outcomes: + msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" raise ValueError( - f"No outcomes found in the MedRecord for group {self._outcomes_group}" + msg ) for outcome in outcomes: @@ -389,7 +399,7 @@ def _find_controls( medrecord: MedRecord, control_group: Set[NodeIndex], treated_group: Set[NodeIndex], - rejected_nodes: Set[NodeIndex] = set(), + rejected_nodes: Optional[Set[NodeIndex]] = None, filter_controls_operation: Optional[NodeOperation] = None, ) -> Tuple[Set[NodeIndex], Set[NodeIndex]]: """Identifies control groups among patients who did not undergo the specified treatments. @@ -427,6 +437,8 @@ def _find_controls( outcome group. """ # Apply the filter to the control group if specified + if rejected_nodes is None: + rejected_nodes = set() if filter_controls_operation: control_group = ( set(medrecord.select_nodes(filter_controls_operation)) & control_group @@ -434,14 +446,16 @@ def _find_controls( control_group = control_group - treated_group - rejected_nodes if len(control_group) == 0: - raise ValueError("No patients found for control groups in this MedRecord.") + msg = "No patients found for control groups in this MedRecord." + raise ValueError(msg) control_true = set() control_false = set() outcomes = medrecord.nodes_in_group(self._outcomes_group) if not outcomes: + msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" raise ValueError( - f"No outcomes found in the MedRecord for group {self._outcomes_group}" + msg ) # Finding the patients that had the outcome in the control group From 9d40b65edd59894722a2e93b15b6b32d568b41e7 Mon Sep 17 00:00:00 2001 From: FloLimebit Date: Fri, 11 Oct 2024 09:34:23 -0600 Subject: [PATCH 4/7] ruff check fixes --- medmodels/medrecord/_overview.py | 2 +- medmodels/medrecord/medrecord.py | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/medmodels/medrecord/_overview.py b/medmodels/medrecord/_overview.py index e22a16a1..f4876461 100644 --- a/medmodels/medrecord/_overview.py +++ b/medmodels/medrecord/_overview.py @@ -184,7 +184,7 @@ def prettify_table( info_order = ["min", "max", "mean", "values"] - for group in data.keys(): + for group in data: # determine longest group name and count lengths[0] = max(len(str(group)), lengths[0]) diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index 13e2bb96..73698d74 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -91,7 +91,6 @@ def __init__( group_header (str): Header for group column, i.e. 'Group Nodes'. decimal (int): Decimal point to round the float values to. """ - self.data = data self.group_header = group_header self.decimal = decimal @@ -843,22 +842,21 @@ def add_edges( edges ): return self.add_edges_pandas(edges, group) - elif is_polars_edge_dataframe_input( + if is_polars_edge_dataframe_input( edges ) or is_polars_edge_dataframe_input_list(edges): return self.add_edges_polars(edges, group) - else: - edge_indices = self._medrecord.add_edges(edges) + edge_indices = self._medrecord.add_edges(edges) - if group is None: - return edge_indices + if group is None: + return edge_indices - if not self.contains_group(group): - self.add_group(group) + if not self.contains_group(group): + self.add_group(group) - self.add_edge_to_group(group, edge_indices) + self.add_edge_to_group(group, edge_indices) - return edge_indices + return edge_indices def add_edges_pandas( self, From 17bd45e096d9d0fad2a3b0b31a768d0c34a35ca4 Mon Sep 17 00:00:00 2001 From: FloLimebit Date: Fri, 11 Oct 2024 09:48:48 -0600 Subject: [PATCH 5/7] ruff formatting --- docs/developer_guide/example_docstrings.py | 2 +- medmodels/_medmodels.pyi | 1 - medmodels/medrecord/medrecord.py | 4 +- medmodels/medrecord/schema.py | 8 +- medmodels/medrecord/tests/test_datatype.py | 5 +- medmodels/medrecord/tests/test_indexers.py | 684 +++++++++++++--- medmodels/medrecord/tests/test_medrecord.py | 45 +- medmodels/medrecord/tests/test_querying.py | 750 ++++++++++++++---- medmodels/medrecord/tests/test_schema.py | 39 +- medmodels/treatment_effect/estimate.py | 16 +- .../algorithms/classic_distance_models.py | 4 +- .../tests/test_temporal_analysis.py | 12 +- .../tests/test_treatment_effect.py | 78 +- .../treatment_effect/treatment_effect.py | 12 +- 14 files changed, 1363 insertions(+), 297 deletions(-) diff --git a/docs/developer_guide/example_docstrings.py b/docs/developer_guide/example_docstrings.py index 7d1306a3..b8e740c6 100644 --- a/docs/developer_guide/example_docstrings.py +++ b/docs/developer_guide/example_docstrings.py @@ -8,7 +8,7 @@ def example_function_args( param2: str | int, optional_param: list[str] | None = None, *args: float | str, - **kwargs: Dict[str, Any] + **kwargs: Dict[str, Any], ) -> tuple[bool, list[str]]: """Example function with PEP 484 type annotations and PEP 563 future annotations. diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index f8ebf367..543c4582 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -1,4 +1,3 @@ - from enum import Enum from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index 7a98c824..c7640516 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -776,9 +776,9 @@ def add_edges( edges ): return self.add_edges_pandas(edges, group) - if is_polars_edge_dataframe_input( + if is_polars_edge_dataframe_input(edges) or is_polars_edge_dataframe_input_list( edges - ) or is_polars_edge_dataframe_input_list(edges): + ): return self.add_edges_polars(edges, group) if is_edge_tuple(edges): edges = [edges] diff --git a/medmodels/medrecord/schema.py b/medmodels/medrecord/schema.py index bcc4e85d..5b0cd5fc 100644 --- a/medmodels/medrecord/schema.py +++ b/medmodels/medrecord/schema.py @@ -252,8 +252,12 @@ class GroupSchema: def __init__( self, *, - nodes: Optional[Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]] = None, - edges: Optional[Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]]] = None, + nodes: Optional[ + Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] + ] = None, + edges: Optional[ + Dict[MedRecordAttribute, Union[DataType, Tuple[DataType, AttributeType]]] + ] = None, strict: bool = False, ) -> None: """Initializes a new instance of GroupSchema. diff --git a/medmodels/medrecord/tests/test_datatype.py b/medmodels/medrecord/tests/test_datatype.py index 331519bb..6fd3c1c0 100644 --- a/medmodels/medrecord/tests/test_datatype.py +++ b/medmodels/medrecord/tests/test_datatype.py @@ -107,7 +107,10 @@ def test_union(self) -> None: assert str(union) == "Union(String, Union(Int, Bool))" - assert union.__repr__() == "DataType.Union(DataType.String, DataType.Union(DataType.Int, DataType.Bool))" + assert ( + union.__repr__() + == "DataType.Union(DataType.String, DataType.Union(DataType.Int, DataType.Bool))" + ) assert mr.Union(mr.String(), mr.Int()) == mr.Union(mr.String(), mr.Int()) assert mr.Union(mr.String(), mr.Int()) != mr.Union(mr.Int(), mr.String()) diff --git a/medmodels/medrecord/tests/test_indexers.py b/medmodels/medrecord/tests/test_indexers.py index b8879e82..49799a57 100644 --- a/medmodels/medrecord/tests/test_indexers.py +++ b/medmodels/medrecord/tests/test_indexers.py @@ -54,7 +54,10 @@ def test_node_getitem(self) -> None: with pytest.raises(ValueError): medrecord.node[0, ::1] - assert medrecord.node[[0, 1]] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} + assert medrecord.node[[0, 1]] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } with pytest.raises(IndexError): medrecord.node[[0, 50]] @@ -69,7 +72,10 @@ def test_node_getitem(self) -> None: with pytest.raises(KeyError): medrecord.node[[0, 1], "lorem"] - assert medrecord.node[[0, 1], ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}} + assert medrecord.node[[0, 1], ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + } # Accessing a non-existing key should fail with pytest.raises(KeyError): @@ -79,7 +85,10 @@ def test_node_getitem(self) -> None: with pytest.raises(KeyError): medrecord.node[[0, 1], ["foo", "lorem"]] - assert medrecord.node[[0, 1], :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} + assert medrecord.node[[0, 1], :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } with pytest.raises(ValueError): medrecord.node[[0, 1], 1:] @@ -88,7 +97,10 @@ def test_node_getitem(self) -> None: with pytest.raises(ValueError): medrecord.node[[0, 1], ::1] - assert medrecord.node[node().index() >= 2] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[node().index() >= 2] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Empty query should not fail assert medrecord.node[node().index() > 3] == {} @@ -99,7 +111,10 @@ def test_node_getitem(self) -> None: with pytest.raises(KeyError): medrecord.node[node().index() >= 2, "test"] - assert medrecord.node[node().index() >= 2, ["foo", "bar"]] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[node().index() >= 2, ["foo", "bar"]] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail with pytest.raises(KeyError): @@ -109,7 +124,10 @@ def test_node_getitem(self) -> None: with pytest.raises(KeyError): medrecord.node[node().index() < 2, ["foo", "lorem"]] - assert medrecord.node[node().index() >= 2, :] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[node().index() >= 2, :] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[node().index() >= 2, 1:] @@ -118,7 +136,12 @@ def test_node_getitem(self) -> None: with pytest.raises(ValueError): medrecord.node[node().index() >= 2, ::1] - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[1:] @@ -140,7 +163,12 @@ def test_node_getitem(self) -> None: with pytest.raises(ValueError): medrecord.node[::1, "foo"] - assert medrecord.node[:, ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:, ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail with pytest.raises(KeyError): @@ -157,7 +185,12 @@ def test_node_getitem(self) -> None: with pytest.raises(ValueError): medrecord.node[::1, ["foo", "bar"]] - assert medrecord.node[:, :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:, :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[1:, :] @@ -177,7 +210,12 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[0] = {"foo": "bar", "bar": "test"} - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Updating a non-existing node should fail @@ -186,15 +224,30 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[0, "foo"] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[0, ["foo", "bar"]] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[0, :] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[0, 1:] = "test" @@ -205,15 +258,30 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[[0, 1], "foo"] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], ["foo", "bar"]] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], :] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[[0, 1], 1:] = "test" @@ -224,7 +292,12 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[node().index() >= 2] = {"foo": "bar", "bar": "test"} - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "test"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "test"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail @@ -232,15 +305,30 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[node().index() >= 2, "foo"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, ["foo", "bar"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, :] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[node().index() >= 2, 1:] = "test" @@ -251,7 +339,12 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[:, "foo"] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[1:, "foo"] = "test" @@ -262,7 +355,12 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[:, ["foo", "bar"]] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[1:, ["foo", "bar"]] = "test" @@ -273,7 +371,12 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[:, :] = "test" - assert medrecord.node[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.node[1:, :] = "test" @@ -292,66 +395,159 @@ def test_node_setitem(self) -> None: medrecord = create_medrecord() medrecord.node[0, "test"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[0, ["test", "test2"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], "test"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], ["test", "test2"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, "test"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() >= 2, ["test", "test2"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } medrecord = create_medrecord() medrecord.node[:, "test"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.node[:, ["test", "test2"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} + assert medrecord.node[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } # Adding and updating attributes medrecord = create_medrecord() medrecord.node[[0, 1], "lorem"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[[0, 1], ["lorem", "test"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() < 2, "lorem"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[node().index() < 2, ["lorem", "test"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.node[:, "lorem"] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test"}, + } medrecord = create_medrecord() medrecord.node[:, ["lorem", "test"]] = "test" - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, + } def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[0, "foo"] - assert medrecord.node[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing node should fail @@ -365,7 +561,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[0, ["foo", "bar"]] - assert medrecord.node[:] == {0: {"lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -374,7 +575,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[0, :] - assert medrecord.node[:] == {0: {}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): del medrecord.node[0, 1:] @@ -385,7 +591,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[[0, 1], "foo"] - assert medrecord.node[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing node should fail @@ -399,7 +610,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[[0, 1], ["foo", "bar"]] - assert medrecord.node[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"lorem": "ipsum"}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -413,7 +629,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[[0, 1], :] - assert medrecord.node[:] == {0: {}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): del medrecord.node[[0, 1], 1:] @@ -424,12 +645,22 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[node().index() >= 2, "foo"] - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail del medrecord.node[node().index() > 3, "foo"] - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -438,7 +669,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[node().index() >= 2, ["foo", "bar"]] - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -452,7 +688,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[node().index() >= 2, :] - assert medrecord.node[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} + assert medrecord.node[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } with pytest.raises(ValueError): del medrecord.node[node().index() >= 2, 1:] @@ -463,7 +704,12 @@ def test_node_delitem(self) -> None: medrecord = create_medrecord() del medrecord.node[:, "foo"] - assert medrecord.node[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} + assert medrecord.node[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -545,7 +791,10 @@ def test_edge_getitem(self) -> None: with pytest.raises(ValueError): medrecord.edge[0, ::1] - assert medrecord.edge[[0, 1]] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} + assert medrecord.edge[[0, 1]] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } with pytest.raises(IndexError): medrecord.edge[[0, 50]] @@ -560,7 +809,10 @@ def test_edge_getitem(self) -> None: with pytest.raises(KeyError): medrecord.edge[[0, 1], "lorem"] - assert medrecord.edge[[0, 1], ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}} + assert medrecord.edge[[0, 1], ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + } # Accessing a non-existing key should fail with pytest.raises(KeyError): @@ -570,7 +822,10 @@ def test_edge_getitem(self) -> None: with pytest.raises(KeyError): medrecord.edge[[0, 1], ["foo", "lorem"]] - assert medrecord.edge[[0, 1], :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}} + assert medrecord.edge[[0, 1], :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + } with pytest.raises(ValueError): medrecord.edge[[0, 1], 1:] @@ -579,7 +834,10 @@ def test_edge_getitem(self) -> None: with pytest.raises(ValueError): medrecord.edge[[0, 1], ::1] - assert medrecord.edge[edge().index() >= 2] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[edge().index() >= 2] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Empty query should not fail assert medrecord.edge[edge().index() > 3] == {} @@ -590,7 +848,10 @@ def test_edge_getitem(self) -> None: with pytest.raises(KeyError): medrecord.edge[edge().index() >= 2, "test"] - assert medrecord.edge[edge().index() >= 2, ["foo", "bar"]] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[edge().index() >= 2, ["foo", "bar"]] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail with pytest.raises(KeyError): @@ -600,7 +861,10 @@ def test_edge_getitem(self) -> None: with pytest.raises(KeyError): medrecord.edge[edge().index() < 2, ["foo", "lorem"]] - assert medrecord.edge[edge().index() >= 2, :] == {2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[edge().index() >= 2, :] == { + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, 1:] @@ -609,7 +873,12 @@ def test_edge_getitem(self) -> None: with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, ::1] - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[1:] @@ -631,7 +900,12 @@ def test_edge_getitem(self) -> None: with pytest.raises(ValueError): medrecord.edge[::1, "foo"] - assert medrecord.edge[:, ["foo", "bar"]] == {0: {"foo": "bar", "bar": "foo"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:, ["foo", "bar"]] == { + 0: {"foo": "bar", "bar": "foo"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } # Accessing a non-existing key should fail with pytest.raises(KeyError): @@ -648,7 +922,12 @@ def test_edge_getitem(self) -> None: with pytest.raises(ValueError): medrecord.edge[::1, ["foo", "bar"]] - assert medrecord.edge[:, :] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:, :] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[1:, :] @@ -668,7 +947,12 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[0] = {"foo": "bar", "bar": "test"} - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Updating a non-existing edge should fail @@ -677,15 +961,30 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[0, "foo"] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[0, ["foo", "bar"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[0, :] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[0, 1:] = "test" @@ -696,15 +995,30 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[[0, 1], "foo"] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], ["foo", "bar"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], :] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[[0, 1], 1:] = "test" @@ -715,7 +1029,12 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[edge().index() >= 2] = {"foo": "bar", "bar": "test"} - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "test"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "test"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail @@ -723,15 +1042,30 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, "foo"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, ["foo", "bar"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, :] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[edge().index() >= 2, 1:] = "test" @@ -742,7 +1076,12 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[:, "foo"] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "foo"}, 2: {"foo": "test", "bar": "foo"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "foo"}, + 2: {"foo": "test", "bar": "foo"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[1:, "foo"] = "test" @@ -753,7 +1092,12 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[:, ["foo", "bar"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "ipsum"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[1:, ["foo", "bar"]] = "test" @@ -764,7 +1108,12 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[:, :] = "test" - assert medrecord.edge[:] == {0: {"foo": "test", "bar": "test", "lorem": "test"}, 1: {"foo": "test", "bar": "test"}, 2: {"foo": "test", "bar": "test"}, 3: {"foo": "test", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "test", "bar": "test", "lorem": "test"}, + 1: {"foo": "test", "bar": "test"}, + 2: {"foo": "test", "bar": "test"}, + 3: {"foo": "test", "bar": "test"}, + } with pytest.raises(ValueError): medrecord.edge[1:, :] = "test" @@ -783,66 +1132,159 @@ def test_edge_setitem(self) -> None: medrecord = create_medrecord() medrecord.edge[0, "test"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[0, ["test", "test2"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], "test"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], ["test", "test2"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, "test"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() >= 2, ["test", "test2"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, "test"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, ["test", "test2"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum", "test": "test", "test2": "test"}, 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}} + assert medrecord.edge[:] == { + 0: { + "foo": "bar", + "bar": "foo", + "lorem": "ipsum", + "test": "test", + "test2": "test", + }, + 1: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 2: {"foo": "bar", "bar": "foo", "test": "test", "test2": "test"}, + 3: {"foo": "bar", "bar": "test", "test": "test", "test2": "test"}, + } # Adding and updating attributes medrecord = create_medrecord() medrecord.edge[[0, 1], "lorem"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[[0, 1], ["lorem", "test"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() < 2, "lorem"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[edge().index() < 2, ["lorem", "test"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, "lorem"] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test"}, + } medrecord = create_medrecord() medrecord.edge[:, ["lorem", "test"]] = "test" - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 1: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 2: {"foo": "bar", "bar": "foo", "lorem": "test", "test": "test"}, + 3: {"foo": "bar", "bar": "test", "lorem": "test", "test": "test"}, + } def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[0, "foo"] - assert medrecord.edge[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing edge should fail @@ -856,7 +1298,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[0, ["foo", "bar"]] - assert medrecord.edge[:] == {0: {"lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -865,7 +1312,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[0, :] - assert medrecord.edge[:] == {0: {}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): del medrecord.edge[0, 1:] @@ -876,7 +1328,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[[0, 1], "foo"] - assert medrecord.edge[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing from a non-existing edge should fail @@ -890,7 +1347,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[[0, 1], ["foo", "bar"]] - assert medrecord.edge[:] == {0: {"lorem": "ipsum"}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"lorem": "ipsum"}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -904,7 +1366,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[[0, 1], :] - assert medrecord.edge[:] == {0: {}, 1: {}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {}, + 1: {}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } with pytest.raises(ValueError): del medrecord.edge[[0, 1], 1:] @@ -915,12 +1382,22 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, "foo"] - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Empty query should not fail del medrecord.edge[edge().index() > 3, "foo"] - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {"foo": "bar", "bar": "foo"}, 3: {"foo": "bar", "bar": "test"}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {"foo": "bar", "bar": "foo"}, + 3: {"foo": "bar", "bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -929,7 +1406,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, ["foo", "bar"]] - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } medrecord = create_medrecord() # Removing a non-existing key should fail @@ -943,7 +1425,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[edge().index() >= 2, :] - assert medrecord.edge[:] == {0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, 1: {"foo": "bar", "bar": "foo"}, 2: {}, 3: {}} + assert medrecord.edge[:] == { + 0: {"foo": "bar", "bar": "foo", "lorem": "ipsum"}, + 1: {"foo": "bar", "bar": "foo"}, + 2: {}, + 3: {}, + } with pytest.raises(ValueError): del medrecord.edge[edge().index() >= 2, 1:] @@ -954,7 +1441,12 @@ def test_edge_delitem(self) -> None: medrecord = create_medrecord() del medrecord.edge[:, "foo"] - assert medrecord.edge[:] == {0: {"bar": "foo", "lorem": "ipsum"}, 1: {"bar": "foo"}, 2: {"bar": "foo"}, 3: {"bar": "test"}} + assert medrecord.edge[:] == { + 0: {"bar": "foo", "lorem": "ipsum"}, + 1: {"bar": "foo"}, + 2: {"bar": "foo"}, + 3: {"bar": "test"}, + } medrecord = create_medrecord() # Removing a non-existing key should fail diff --git a/medmodels/medrecord/tests/test_medrecord.py b/medmodels/medrecord/tests/test_medrecord.py index 20970043..314ca797 100644 --- a/medmodels/medrecord/tests/test_medrecord.py +++ b/medmodels/medrecord/tests/test_medrecord.py @@ -327,7 +327,10 @@ def test_group(self) -> None: assert medrecord.group("1") == {"nodes": ["0"], "edges": [0]} - assert medrecord.group(["0", "1"]) == {"0": {"nodes": [], "edges": []}, "1": {"nodes": ["0"], "edges": [0]}} + assert medrecord.group(["0", "1"]) == { + "0": {"nodes": [], "edges": []}, + "1": {"nodes": ["0"], "edges": [0]}, + } def test_invalid_group(self) -> None: medrecord = create_medrecord() @@ -351,11 +354,17 @@ def test_outgoing_edges(self) -> None: edges = medrecord.outgoing_edges(["0", "1"]) - assert {key: sorted(value) for key, value in edges.items()} == {"0": sorted([0, 3]), "1": [1, 2]} + assert {key: sorted(value) for key, value in edges.items()} == { + "0": sorted([0, 3]), + "1": [1, 2], + } edges = medrecord.outgoing_edges(node_select().index().is_in(["0", "1"])) - assert {key: sorted(value) for key, value in edges.items()} == {"0": sorted([0, 3]), "1": [1, 2]} + assert {key: sorted(value) for key, value in edges.items()} == { + "0": sorted([0, 3]), + "1": [1, 2], + } def test_invalid_outgoing_edges(self) -> None: medrecord = create_medrecord() @@ -1383,7 +1392,10 @@ def test_groups_of_node(self) -> None: assert medrecord.groups_of_node(["0", "1"]) == {"0": ["0"], "1": ["0"]} - assert medrecord.groups_of_node(node_select().index().is_in(["0", "1"])) == {"0": ["0"], "1": ["0"]} + assert medrecord.groups_of_node(node_select().index().is_in(["0", "1"])) == { + "0": ["0"], + "1": ["0"], + } def test_invalid_groups_of_node(self) -> None: medrecord = create_medrecord() @@ -1405,7 +1417,10 @@ def test_groups_of_edge(self) -> None: assert medrecord.groups_of_edge([0, 1]) == {0: ["0"], 1: ["0"]} - assert medrecord.groups_of_edge(edge_select().index().is_in([0, 1])) == {0: ["0"], 1: ["0"]} + assert medrecord.groups_of_edge(edge_select().index().is_in([0, 1])) == { + 0: ["0"], + 1: ["0"], + } def test_invalid_groups_of_edge(self) -> None: medrecord = create_medrecord() @@ -1480,11 +1495,17 @@ def test_neighbors(self) -> None: neighbors = medrecord.neighbors(["0", "1"]) - assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } neighbors = medrecord.neighbors(node_select().index().is_in(["0", "1"])) - assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } neighbors = medrecord.neighbors("0", directed=False) @@ -1492,13 +1513,19 @@ def test_neighbors(self) -> None: neighbors = medrecord.neighbors(["0", "1"], directed=False) - assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } neighbors = medrecord.neighbors( node_select().index().is_in(["0", "1"]), directed=False ) - assert {key: sorted(value) for key, value in neighbors.items()} == {"0": sorted(["1", "3"]), "1": ["0", "2"]} + assert {key: sorted(value) for key, value in neighbors.items()} == { + "0": sorted(["1", "3"]), + "1": ["0", "2"], + } def test_invalid_neighbors(self) -> None: medrecord = create_medrecord() diff --git a/medmodels/medrecord/tests/test_querying.py b/medmodels/medrecord/tests/test_querying.py index d7b5b7e2..c0464efa 100644 --- a/medmodels/medrecord/tests/test_querying.py +++ b/medmodels/medrecord/tests/test_querying.py @@ -50,47 +50,75 @@ def test_select_nodes_node(self) -> None: assert medrecord.select_nodes(node().has_attribute("lorem")) == ["0"] # Node has outgoing edge with - assert medrecord.select_nodes(node().has_outgoing_edge_with(edge().index().equal(0))) == ["0"] + assert medrecord.select_nodes( + node().has_outgoing_edge_with(edge().index().equal(0)) + ) == ["0"] # Node has incoming edge with - assert medrecord.select_nodes(node().has_incoming_edge_with(edge().index().equal(0))) == ["1"] + assert medrecord.select_nodes( + node().has_incoming_edge_with(edge().index().equal(0)) + ) == ["1"] # Node has edge with - assert sorted(["0", "1"]) == sorted(medrecord.select_nodes(node().has_edge_with(edge().index().equal(0)))) + assert sorted(["0", "1"]) == sorted( + medrecord.select_nodes(node().has_edge_with(edge().index().equal(0))) + ) # Node has neighbor with - assert sorted(["0", "1"]) == sorted(medrecord.select_nodes(node().has_neighbor_with(node().index().equal("2")))) - assert sorted(["0"]) == sorted(medrecord.select_nodes(node().has_neighbor_with(node().index().equal("1"), directed=True))) + assert sorted(["0", "1"]) == sorted( + medrecord.select_nodes(node().has_neighbor_with(node().index().equal("2"))) + ) + assert sorted(["0"]) == sorted( + medrecord.select_nodes( + node().has_neighbor_with(node().index().equal("1"), directed=True) + ) + ) # Node has neighbor with - assert sorted(["0", "2"]) == sorted(medrecord.select_nodes(node().has_neighbor_with(node().index().equal("1"), directed=False))) + assert sorted(["0", "2"]) == sorted( + medrecord.select_nodes( + node().has_neighbor_with(node().index().equal("1"), directed=False) + ) + ) def test_select_nodes_node_index(self) -> None: medrecord = create_medrecord() # Index greater - assert sorted(["2", "3"]) == sorted(medrecord.select_nodes(node().index().greater("1"))) + assert sorted(["2", "3"]) == sorted( + medrecord.select_nodes(node().index().greater("1")) + ) # Index less - assert sorted(["0", "1"]) == sorted(medrecord.select_nodes(node().index().less("2"))) + assert sorted(["0", "1"]) == sorted( + medrecord.select_nodes(node().index().less("2")) + ) # Index greater or equal - assert sorted(["1", "2", "3"]) == sorted(medrecord.select_nodes(node().index().greater_or_equal("1"))) + assert sorted(["1", "2", "3"]) == sorted( + medrecord.select_nodes(node().index().greater_or_equal("1")) + ) # Index less or equal - assert sorted(["0", "1", "2"]) == sorted(medrecord.select_nodes(node().index().less_or_equal("2"))) + assert sorted(["0", "1", "2"]) == sorted( + medrecord.select_nodes(node().index().less_or_equal("2")) + ) # Index equal assert medrecord.select_nodes(node().index().equal("1")) == ["1"] # Index not equal - assert sorted(["0", "2", "3"]) == sorted(medrecord.select_nodes(node().index().not_equal("1"))) + assert sorted(["0", "2", "3"]) == sorted( + medrecord.select_nodes(node().index().not_equal("1")) + ) # Index in assert medrecord.select_nodes(node().index().is_in(["1"])) == ["1"] # Index not in - assert sorted(["0", "2", "3"]) == sorted(medrecord.select_nodes(node().index().not_in(["1"]))) + assert sorted(["0", "2", "3"]) == sorted( + medrecord.select_nodes(node().index().not_in(["1"])) + ) # Index starts with assert medrecord.select_nodes(node().index().starts_with("1")) == ["1"] @@ -113,11 +141,15 @@ def test_select_nodes_node_attribute(self) -> None: assert medrecord.select_nodes(node().attribute("lorem") < "ipsum") == [] # Attribute greater or equal - assert medrecord.select_nodes(node().attribute("lorem").greater_or_equal("ipsum")) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem").greater_or_equal("ipsum") + ) == ["0"] assert medrecord.select_nodes(node().attribute("lorem") >= "ipsum") == ["0"] # Attribute less or equal - assert medrecord.select_nodes(node().attribute("lorem").less_or_equal("ipsum")) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem").less_or_equal("ipsum") + ) == ["0"] assert medrecord.select_nodes(node().attribute("lorem") <= "ipsum") == ["0"] # Attribute equal @@ -125,120 +157,332 @@ def test_select_nodes_node_attribute(self) -> None: assert medrecord.select_nodes(node().attribute("lorem") == "ipsum") == ["0"] # Attribute not equal - assert medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) == [] + assert ( + medrecord.select_nodes(node().attribute("lorem").not_equal("ipsum")) == [] + ) assert medrecord.select_nodes(node().attribute("lorem") != "ipsum") == [] # Attribute in - assert medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem").is_in(["ipsum"])) == [ + "0" + ] # Attribute not in assert medrecord.select_nodes(node().attribute("lorem").not_in(["ipsum"])) == [] # Attribute starts with - assert medrecord.select_nodes(node().attribute("lorem").starts_with("ip")) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem").starts_with("ip")) == [ + "0" + ] # Attribute ends with - assert medrecord.select_nodes(node().attribute("lorem").ends_with("um")) == ["0"] + assert medrecord.select_nodes(node().attribute("lorem").ends_with("um")) == [ + "0" + ] # Attribute contains assert medrecord.select_nodes(node().attribute("lorem").contains("su")) == ["0"] # Attribute compare to attribute - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem"))) == ["0"] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem"))) == [] + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem")) + ) == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem")) + ) + == [] + ) # Attribute compare to attribute add - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").add("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") + "10") == [] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").add("10"))) == ["0"] - assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") + "10") == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").add("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") == node().attribute("lorem") + "10" + ) + == [] + ) + assert medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").add("10")) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") + "10" + ) == ["0"] # Attribute compare to attribute sub # Returns nothing because can't sub a string - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").sub("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") + "10") == [] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").sub("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") - "10") == [] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").sub("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") == node().attribute("lorem") + "10" + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").sub("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") - "10" + ) + == [] + ) # Attribute compare to attribute sub - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").sub(10))) == [] - assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").sub(10))) == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").sub(10)) + ) + == [] + ) + assert medrecord.select_nodes( + node().attribute("integer").not_equal(node().attribute("integer").sub(10)) + ) == ["0"] # Attribute compare to attribute mul - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").mul(2))) == [] - assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") * 2) == [] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").mul(2))) == ["0"] - assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") * 2) == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").mul(2)) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") == node().attribute("lorem") * 2 + ) + == [] + ) + assert medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").mul(2)) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") * 2 + ) == ["0"] # Attribute compare to attribute div # Returns nothing because can't div a string - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").div("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") / "10") == [] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").div("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") / "10") == [] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").div("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") == node().attribute("lorem") / "10" + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").div("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") / "10" + ) + == [] + ) # Attribute compare to attribute div - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").div(2))) == [] - assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").div(2))) == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").div(2)) + ) + == [] + ) + assert medrecord.select_nodes( + node().attribute("integer").not_equal(node().attribute("integer").div(2)) + ) == ["0"] # Attribute compare to attribute pow # Returns nothing because can't pow a string - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").pow("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") ** "10") == [] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").pow("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") ** "10") == [] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").pow("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") == node().attribute("lorem") ** "10" + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").pow("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") ** "10" + ) + == [] + ) # Attribute compare to attribute pow - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").pow(2))) == ["0"] - assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").pow(2))) == [] + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").pow(2)) + ) == ["0"] + assert ( + medrecord.select_nodes( + node() + .attribute("integer") + .not_equal(node().attribute("integer").pow(2)) + ) + == [] + ) # Attribute compare to attribute mod # Returns nothing because can't mod a string - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").mod("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") == node().attribute("lorem") % "10") == [] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").mod("10"))) == [] - assert medrecord.select_nodes(node().attribute("lorem") != node().attribute("lorem") % "10") == [] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").mod("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") == node().attribute("lorem") % "10" + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").mod("10")) + ) + == [] + ) + assert ( + medrecord.select_nodes( + node().attribute("lorem") != node().attribute("lorem") % "10" + ) + == [] + ) # Attribute compare to attribute mod - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").mod(2))) == ["0"] - assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").mod(2))) == [] + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").mod(2)) + ) == ["0"] + assert ( + medrecord.select_nodes( + node() + .attribute("integer") + .not_equal(node().attribute("integer").mod(2)) + ) + == [] + ) # Attribute compare to attribute round - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("lorem").round())) == ["0"] - assert medrecord.select_nodes(node().attribute("lorem").not_equal(node().attribute("lorem").round())) == [] - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("float").round())) == ["0"] - assert medrecord.select_nodes(node().attribute("float").not_equal(node().attribute("float").round())) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("lorem").round()) + ) == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("lorem").not_equal(node().attribute("lorem").round()) + ) + == [] + ) + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("float").round()) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("float").not_equal(node().attribute("float").round()) + ) == ["0"] # Attribute compare to attribute round - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("float").ceil())) == ["0"] - assert medrecord.select_nodes(node().attribute("float").not_equal(node().attribute("float").ceil())) == ["0"] + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("float").ceil()) + ) == ["0"] + assert medrecord.select_nodes( + node().attribute("float").not_equal(node().attribute("float").ceil()) + ) == ["0"] # Attribute compare to attribute floor - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("float").floor())) == [] - assert medrecord.select_nodes(node().attribute("float").not_equal(node().attribute("float").floor())) == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("float").floor()) + ) + == [] + ) + assert medrecord.select_nodes( + node().attribute("float").not_equal(node().attribute("float").floor()) + ) == ["0"] # Attribute compare to attribute abs - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").abs())) == ["0"] - assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").abs())) == [] + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").abs()) + ) == ["0"] + assert ( + medrecord.select_nodes( + node().attribute("integer").not_equal(node().attribute("integer").abs()) + ) + == [] + ) # Attribute compare to attribute sqrt - assert medrecord.select_nodes(node().attribute("integer").equal(node().attribute("integer").sqrt())) == ["0"] - assert medrecord.select_nodes(node().attribute("integer").not_equal(node().attribute("integer").sqrt())) == [] + assert medrecord.select_nodes( + node().attribute("integer").equal(node().attribute("integer").sqrt()) + ) == ["0"] + assert ( + medrecord.select_nodes( + node() + .attribute("integer") + .not_equal(node().attribute("integer").sqrt()) + ) + == [] + ) # Attribute compare to attribute trim - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("dolor").trim())) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("dolor").trim()) + ) == ["0"] # Attribute compare to attribute trim_start - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("dolor").trim_start())) == [] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("dolor").trim_start()) + ) + == [] + ) # Attribute compare to attribute trim_end - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("dolor").trim_end())) == [] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("dolor").trim_end()) + ) + == [] + ) # Attribute compare to attribute lowercase - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("test").lowercase())) == ["0"] + assert medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("test").lowercase()) + ) == ["0"] # Attribute compare to attribute uppercase - assert medrecord.select_nodes(node().attribute("lorem").equal(node().attribute("test").uppercase())) == [] + assert ( + medrecord.select_nodes( + node().attribute("lorem").equal(node().attribute("test").uppercase()) + ) + == [] + ) def test_select_edges_edge(self) -> None: medrecord = create_medrecord() @@ -246,10 +490,14 @@ def test_select_edges_edge(self) -> None: medrecord.add_group("test", edges=[0]) # Edge connected to target - assert sorted([1, 2, 3]) == sorted(medrecord.select_edges(edge().connected_target("2"))) + assert sorted([1, 2, 3]) == sorted( + medrecord.select_edges(edge().connected_target("2")) + ) # Edge connected to source - assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().connected_source("0"))) + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges(edge().connected_source("0")) + ) # Edge connected assert sorted([0, 1]) == sorted(medrecord.select_edges(edge().connected("1"))) @@ -261,46 +509,72 @@ def test_select_edges_edge(self) -> None: assert medrecord.select_edges(edge().has_attribute("sed")) == [0] # Edge connected to target with - assert medrecord.select_edges(edge().connected_target_with(node().index().equal("1"))) == [0] + assert medrecord.select_edges( + edge().connected_target_with(node().index().equal("1")) + ) == [0] # Edge connected to source with - assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().connected_source_with(node().index().equal("0")))) + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges( + edge().connected_source_with(node().index().equal("0")) + ) + ) # Edge connected with - assert sorted([0, 1]) == sorted(medrecord.select_edges(edge().connected_with(node().index().equal("1")))) + assert sorted([0, 1]) == sorted( + medrecord.select_edges(edge().connected_with(node().index().equal("1"))) + ) # Edge has parallel edges with - assert sorted([2, 3]) == sorted(medrecord.select_edges(edge().has_parallel_edges_with(edge().has_attribute("test")))) + assert sorted([2, 3]) == sorted( + medrecord.select_edges( + edge().has_parallel_edges_with(edge().has_attribute("test")) + ) + ) # Edge has parallel edges with self comparison - assert medrecord.select_edges(edge().has_parallel_edges_with_self_comparison(edge().attribute("test").equal(edge().attribute("test").sub(1)))) == [2] + assert medrecord.select_edges( + edge().has_parallel_edges_with_self_comparison( + edge().attribute("test").equal(edge().attribute("test").sub(1)) + ) + ) == [2] def test_select_edges_edge_index(self) -> None: medrecord = create_medrecord() # Index greater - assert sorted([2, 3]) == sorted(medrecord.select_edges(edge().index().greater(1))) + assert sorted([2, 3]) == sorted( + medrecord.select_edges(edge().index().greater(1)) + ) # Index less assert medrecord.select_edges(edge().index().less(1)) == [0] # Index greater or equal - assert sorted([1, 2, 3]) == sorted(medrecord.select_edges(edge().index().greater_or_equal(1))) + assert sorted([1, 2, 3]) == sorted( + medrecord.select_edges(edge().index().greater_or_equal(1)) + ) # Index less or equal - assert sorted([0, 1]) == sorted(medrecord.select_edges(edge().index().less_or_equal(1))) + assert sorted([0, 1]) == sorted( + medrecord.select_edges(edge().index().less_or_equal(1)) + ) # Index equal assert medrecord.select_edges(edge().index().equal(1)) == [1] # Index not equal - assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().index().not_equal(1))) + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges(edge().index().not_equal(1)) + ) # Index in assert medrecord.select_edges(edge().index().is_in([1])) == [1] # Index not in - assert sorted([0, 2, 3]) == sorted(medrecord.select_edges(edge().index().not_in([1]))) + assert sorted([0, 2, 3]) == sorted( + medrecord.select_edges(edge().index().not_in([1])) + ) def test_select_edges_edges_attribute(self) -> None: medrecord = create_medrecord() @@ -312,10 +586,14 @@ def test_select_edges_edges_attribute(self) -> None: assert medrecord.select_edges(edge().attribute("sed").less("do")) == [] # Attribute greater or equal - assert medrecord.select_edges(edge().attribute("sed").greater_or_equal("do")) == [0] + assert medrecord.select_edges( + edge().attribute("sed").greater_or_equal("do") + ) == [0] # Attribute less or equal - assert medrecord.select_edges(edge().attribute("sed").less_or_equal("do")) == [0] + assert medrecord.select_edges(edge().attribute("sed").less_or_equal("do")) == [ + 0 + ] # Attribute equal assert medrecord.select_edges(edge().attribute("sed").equal("do")) == [0] @@ -339,98 +617,302 @@ def test_select_edges_edges_attribute(self) -> None: assert medrecord.select_edges(edge().attribute("sed").contains("d")) == [0] # Attribute compare to attribute - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed"))) == [0] - assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed"))) == [] + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed")) + ) == [0] + assert ( + medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed")) + ) + == [] + ) # Attribute compare to attribute add - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").add("10"))) == [] - assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") + "10") == [] - assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").add("10"))) == [0] - assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") + "10") == [0] + assert ( + medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed").add("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed") == edge().attribute("sed") + "10" + ) + == [] + ) + assert medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed").add("10")) + ) == [0] + assert medrecord.select_edges( + edge().attribute("sed") != edge().attribute("sed") + "10" + ) == [0] # Attribute compare to attribute sub # Returns nothing because can't sub a string - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").sub("10"))) == [] - assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") - "10") == [] - assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").sub("10"))) == [] - assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") - "10") == [] + assert ( + medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed").sub("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed") == edge().attribute("sed") - "10" + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed").sub("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed") != edge().attribute("sed") - "10" + ) + == [] + ) # Attribute compare to attribute sub - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").sub(10))) == [] - assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").sub(10))) == [2] + assert ( + medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").sub(10)) + ) + == [] + ) + assert medrecord.select_edges( + edge().attribute("integer").not_equal(edge().attribute("integer").sub(10)) + ) == [2] # Attribute compare to attribute mul - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").mul(2))) == [] - assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") * 2) == [] - assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").mul(2))) == [0] - assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") * 2) == [0] + assert ( + medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed").mul(2)) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed") == edge().attribute("sed") * 2 + ) + == [] + ) + assert medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed").mul(2)) + ) == [0] + assert medrecord.select_edges( + edge().attribute("sed") != edge().attribute("sed") * 2 + ) == [0] # Attribute compare to attribute div # Returns nothing because can't div a string - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").div("10"))) == [] - assert medrecord.select_edges(edge().attribute("sed") == edge().attribute("sed") / "10") == [] - assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").div("10"))) == [] - assert medrecord.select_edges(edge().attribute("sed") != edge().attribute("sed") / "10") == [] + assert ( + medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed").div("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed") == edge().attribute("sed") / "10" + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed").div("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("sed") != edge().attribute("sed") / "10" + ) + == [] + ) # Attribute compare to attribute div - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").div(2))) == [] - assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").div(2))) == [2] + assert ( + medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").div(2)) + ) + == [] + ) + assert medrecord.select_edges( + edge().attribute("integer").not_equal(edge().attribute("integer").div(2)) + ) == [2] # Attribute compare to attribute pow # Returns nothing because can't pow a string - assert medrecord.select_edges(edge().attribute("lorem").equal(edge().attribute("lorem").pow("10"))) == [] - assert medrecord.select_edges(edge().attribute("lorem") == edge().attribute("lorem") ** "10") == [] - assert medrecord.select_edges(edge().attribute("lorem").not_equal(edge().attribute("lorem").pow("10"))) == [] - assert medrecord.select_edges(edge().attribute("lorem") != edge().attribute("lorem") ** "10") == [] + assert ( + medrecord.select_edges( + edge().attribute("lorem").equal(edge().attribute("lorem").pow("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("lorem") == edge().attribute("lorem") ** "10" + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("lorem").not_equal(edge().attribute("lorem").pow("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("lorem") != edge().attribute("lorem") ** "10" + ) + == [] + ) # Attribute compare to attribute pow - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").pow(2))) == [2] - assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").pow(2))) == [] + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").pow(2)) + ) == [2] + assert ( + medrecord.select_edges( + edge() + .attribute("integer") + .not_equal(edge().attribute("integer").pow(2)) + ) + == [] + ) # Attribute compare to attribute mod # Returns nothing because can't mod a string - assert medrecord.select_edges(edge().attribute("lorem").equal(edge().attribute("lorem").mod("10"))) == [] - assert medrecord.select_edges(edge().attribute("lorem") == edge().attribute("lorem") % "10") == [] - assert medrecord.select_edges(edge().attribute("lorem").not_equal(edge().attribute("lorem").mod("10"))) == [] - assert medrecord.select_edges(edge().attribute("lorem") != edge().attribute("lorem") % "10") == [] + assert ( + medrecord.select_edges( + edge().attribute("lorem").equal(edge().attribute("lorem").mod("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("lorem") == edge().attribute("lorem") % "10" + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("lorem").not_equal(edge().attribute("lorem").mod("10")) + ) + == [] + ) + assert ( + medrecord.select_edges( + edge().attribute("lorem") != edge().attribute("lorem") % "10" + ) + == [] + ) # Attribute compare to attribute mod - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").mod(2))) == [2] - assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").mod(2))) == [] + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").mod(2)) + ) == [2] + assert ( + medrecord.select_edges( + edge() + .attribute("integer") + .not_equal(edge().attribute("integer").mod(2)) + ) + == [] + ) # Attribute compare to attribute round - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("sed").round())) == [0] - assert medrecord.select_edges(edge().attribute("sed").not_equal(edge().attribute("sed").round())) == [] - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("float").round())) == [2] - assert medrecord.select_edges(edge().attribute("float").not_equal(edge().attribute("float").round())) == [2] + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("sed").round()) + ) == [0] + assert ( + medrecord.select_edges( + edge().attribute("sed").not_equal(edge().attribute("sed").round()) + ) + == [] + ) + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("float").round()) + ) == [2] + assert medrecord.select_edges( + edge().attribute("float").not_equal(edge().attribute("float").round()) + ) == [2] # Attribute compare to attribute ceil - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("float").ceil())) == [2] - assert medrecord.select_edges(edge().attribute("float").not_equal(edge().attribute("float").ceil())) == [2] + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("float").ceil()) + ) == [2] + assert medrecord.select_edges( + edge().attribute("float").not_equal(edge().attribute("float").ceil()) + ) == [2] # Attribute compare to attribute floor - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("float").floor())) == [] - assert medrecord.select_edges(edge().attribute("float").not_equal(edge().attribute("float").floor())) == [2] + assert ( + medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("float").floor()) + ) + == [] + ) + assert medrecord.select_edges( + edge().attribute("float").not_equal(edge().attribute("float").floor()) + ) == [2] # Attribute compare to attribute abs - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").abs())) == [2] - assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").abs())) == [] + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").abs()) + ) == [2] + assert ( + medrecord.select_edges( + edge().attribute("integer").not_equal(edge().attribute("integer").abs()) + ) + == [] + ) # Attribute compare to attribute sqrt - assert medrecord.select_edges(edge().attribute("integer").equal(edge().attribute("integer").sqrt())) == [2] - assert medrecord.select_edges(edge().attribute("integer").not_equal(edge().attribute("integer").sqrt())) == [] + assert medrecord.select_edges( + edge().attribute("integer").equal(edge().attribute("integer").sqrt()) + ) == [2] + assert ( + medrecord.select_edges( + edge() + .attribute("integer") + .not_equal(edge().attribute("integer").sqrt()) + ) + == [] + ) # Attribute compare to attribute trim - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("dolor").trim())) == [0] + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("dolor").trim()) + ) == [0] # Attribute compare to attribute trim_start - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("dolor").trim_start())) == [] + assert ( + medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("dolor").trim_start()) + ) + == [] + ) # Attribute compare to attribute trim_end - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("dolor").trim_end())) == [] + assert ( + medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("dolor").trim_end()) + ) + == [] + ) # Attribute compare to attribute lowercase - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("test").lowercase())) == [0] + assert medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("test").lowercase()) + ) == [0] # Attribute compare to attribute uppercase - assert medrecord.select_edges(edge().attribute("sed").equal(edge().attribute("test").uppercase())) == [] + assert ( + medrecord.select_edges( + edge().attribute("sed").equal(edge().attribute("test").uppercase()) + ) + == [] + ) diff --git a/medmodels/medrecord/tests/test_schema.py b/medmodels/medrecord/tests/test_schema.py index 3dfaa342..e4a72064 100644 --- a/medmodels/medrecord/tests/test_schema.py +++ b/medmodels/medrecord/tests/test_schema.py @@ -16,7 +16,17 @@ def setUp(self) -> None: self.schema = create_medrecord().schema def test_groups(self) -> None: - assert sorted(["diagnosis", "drug", "patient_diagnosis", "patient_drug", "patient_procedure", "patient", "procedure"]) == sorted(self.schema.groups) + assert sorted( + [ + "diagnosis", + "drug", + "patient_diagnosis", + "patient_drug", + "patient_procedure", + "patient", + "procedure", + ] + ) == sorted(self.schema.groups) def test_group(self) -> None: assert isinstance(self.schema.group("patient"), mr.GroupSchema) # pyright: ignore[reportUnnecessaryIsInstance] @@ -39,10 +49,16 @@ def setUp(self) -> None: self.schema = create_medrecord().schema def test_nodes(self) -> None: - assert self.schema.group("patient").nodes == {"age": (mr.Int(), mr.AttributeType.Continuous), "gender": (mr.String(), mr.AttributeType.Categorical)} + assert self.schema.group("patient").nodes == { + "age": (mr.Int(), mr.AttributeType.Continuous), + "gender": (mr.String(), mr.AttributeType.Categorical), + } def test_edges(self) -> None: - assert self.schema.group("patient_diagnosis").edges == {"diagnosis_time": (mr.DateTime(), mr.AttributeType.Temporal), "duration_days": (mr.Option(mr.Float()), mr.AttributeType.Continuous)} + assert self.schema.group("patient_diagnosis").edges == { + "diagnosis_time": (mr.DateTime(), mr.AttributeType.Temporal), + "duration_days": (mr.Option(mr.Float()), mr.AttributeType.Continuous), + } def test_strict(self) -> None: assert True is self.schema.group("patient").strict @@ -60,7 +76,9 @@ def setUp(self) -> None: ) def test_repr(self) -> None: - assert repr(self.attributes_schema) == "{'description': (DataType.String, None)}" + assert ( + repr(self.attributes_schema) == "{'description': (DataType.String, None)}" + ) second_attributes_schema = ( Schema( @@ -77,7 +95,10 @@ def test_repr(self) -> None: .nodes ) - assert repr(second_attributes_schema) == "{'description': (DataType.String, AttributeType.Categorical)}" + assert ( + repr(second_attributes_schema) + == "{'description': (DataType.String, AttributeType.Categorical)}" + ) def test_getitem(self) -> None: assert (mr.String(), None) == self.attributes_schema["description"] @@ -158,14 +179,18 @@ def test_values(self) -> None: assert [(mr.String(), None)] == list(self.attributes_schema.values()) def test_items(self) -> None: - assert [("description", (mr.String(), None))] == list(self.attributes_schema.items()) + assert [("description", (mr.String(), None))] == list( + self.attributes_schema.items() + ) def test_get(self) -> None: assert (mr.String(), None) == self.attributes_schema.get("description") assert None is self.attributes_schema.get("nonexistent") - assert (mr.String(), None) == self.attributes_schema.get("nonexistent", (mr.String(), None)) + assert (mr.String(), None) == self.attributes_schema.get( + "nonexistent", (mr.String(), None) + ) class TestAttributeType(unittest.TestCase): diff --git a/medmodels/treatment_effect/estimate.py b/medmodels/treatment_effect/estimate.py index 65d17fd9..71980781 100644 --- a/medmodels/treatment_effect/estimate.py +++ b/medmodels/treatment_effect/estimate.py @@ -132,25 +132,19 @@ def _check_medrecord(self, medrecord: MedRecord) -> None: f"Patient group {self._treatment_effect._patients_group} not found in " f"the MedRecord. Available groups: {medrecord.groups}" ) - raise ValueError( - msg - ) + raise ValueError(msg) if self._treatment_effect._treatments_group not in medrecord.groups: msg = ( "Treatment group not found in the MedRecord. " f"Available groups: {medrecord.groups}" ) - raise ValueError( - msg - ) + raise ValueError(msg) if self._treatment_effect._outcomes_group not in medrecord.groups: msg = ( "Outcome group not found in the MedRecord." f"Available groups: {medrecord.groups}" ) - raise ValueError( - msg - ) + raise ValueError(msg) def _sort_subjects_in_groups( self, medrecord: MedRecord @@ -541,9 +535,7 @@ def hazard_ratio(self, medrecord: MedRecord) -> float: if hazard_control == 0: msg = "Control hazard rate is zero, cannot calculate hazard ratio." - raise ValueError( - msg - ) + raise ValueError(msg) return hazard_treat / hazard_control diff --git a/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py b/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py index 3c0dd961..19970a28 100644 --- a/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py +++ b/medmodels/treatment_effect/matching/algorithms/classic_distance_models.py @@ -36,9 +36,7 @@ def nearest_neighbor( """ if treated_set.shape[0] * number_of_neighbors > control_set.shape[0]: msg = "The treated set is too large for the given number of neighbors." - raise ValueError( - msg - ) + raise ValueError(msg) if not covariates: covariates = treated_set.columns diff --git a/medmodels/treatment_effect/tests/test_temporal_analysis.py b/medmodels/treatment_effect/tests/test_temporal_analysis.py index a856c78f..35ac6b8c 100644 --- a/medmodels/treatment_effect/tests/test_temporal_analysis.py +++ b/medmodels/treatment_effect/tests/test_temporal_analysis.py @@ -164,7 +164,9 @@ def test_find_reference_time(self) -> None: assert edge == 0 def test_invalid_find_reference_time(self) -> None: - with pytest.raises(ValueError, match="Time attribute not found in the edge attributes"): + with pytest.raises( + ValueError, match="Time attribute not found in the edge attributes" + ): find_reference_edge( self.medrecord, node_index="P1", @@ -174,7 +176,9 @@ def test_invalid_find_reference_time(self) -> None: ) node_index = "P2" - with pytest.raises(ValueError, match=f"No edge found for node {node_index} in this MedRecord"): + with pytest.raises( + ValueError, match=f"No edge found for node {node_index} in this MedRecord" + ): find_reference_edge( self.medrecord, node_index=node_index, @@ -211,7 +215,9 @@ def test_node_in_time_window(self) -> None: assert not node_found2 def test_invalid_node_in_time_window(self) -> None: - with pytest.raises(ValueError, match="Time attribute not found in the edge attributes"): + with pytest.raises( + ValueError, match="Time attribute not found in the edge attributes" + ): find_node_in_time_window( self.medrecord, subject_index="P3", diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index 0c6324d7..287a409f 100644 --- a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -198,20 +198,51 @@ def assert_treatment_effects_equal( assert treatment_effect1._outcomes_group == treatment_effect2._outcomes_group assert treatment_effect1._patients_group == treatment_effect2._patients_group assert treatment_effect1._time_attribute == treatment_effect2._time_attribute - assert treatment_effect1._washout_period_days == treatment_effect2._washout_period_days - assert treatment_effect1._washout_period_reference == treatment_effect2._washout_period_reference + assert ( + treatment_effect1._washout_period_days == treatment_effect2._washout_period_days + ) + assert ( + treatment_effect1._washout_period_reference + == treatment_effect2._washout_period_reference + ) assert treatment_effect1._grace_period_days == treatment_effect2._grace_period_days - assert treatment_effect1._grace_period_reference == treatment_effect2._grace_period_reference - assert treatment_effect1._follow_up_period_days == treatment_effect2._follow_up_period_days - assert treatment_effect1._follow_up_period_reference == treatment_effect2._follow_up_period_reference - assert treatment_effect1._outcome_before_treatment_days == treatment_effect2._outcome_before_treatment_days - assert treatment_effect1._filter_controls_operation == treatment_effect2._filter_controls_operation + assert ( + treatment_effect1._grace_period_reference + == treatment_effect2._grace_period_reference + ) + assert ( + treatment_effect1._follow_up_period_days + == treatment_effect2._follow_up_period_days + ) + assert ( + treatment_effect1._follow_up_period_reference + == treatment_effect2._follow_up_period_reference + ) + assert ( + treatment_effect1._outcome_before_treatment_days + == treatment_effect2._outcome_before_treatment_days + ) + assert ( + treatment_effect1._filter_controls_operation + == treatment_effect2._filter_controls_operation + ) assert treatment_effect1._matching_method == treatment_effect2._matching_method - assert treatment_effect1._matching_essential_covariates == treatment_effect2._matching_essential_covariates - assert treatment_effect1._matching_one_hot_covariates == treatment_effect2._matching_one_hot_covariates + assert ( + treatment_effect1._matching_essential_covariates + == treatment_effect2._matching_essential_covariates + ) + assert ( + treatment_effect1._matching_one_hot_covariates + == treatment_effect2._matching_one_hot_covariates + ) assert treatment_effect1._matching_model == treatment_effect2._matching_model - assert treatment_effect1._matching_number_of_neighbors == treatment_effect2._matching_number_of_neighbors - assert treatment_effect1._matching_hyperparam == treatment_effect2._matching_hyperparam + assert ( + treatment_effect1._matching_number_of_neighbors + == treatment_effect2._matching_number_of_neighbors + ) + assert ( + treatment_effect1._matching_hyperparam == treatment_effect2._matching_hyperparam + ) class TestTreatmentEffect(unittest.TestCase): @@ -264,7 +295,9 @@ def test_check_medrecord(self) -> None: .build() ) - with pytest.raises(ValueError, match="Treatment group not found in the MedRecord"): + with pytest.raises( + ValueError, match="Treatment group not found in the MedRecord" + ): tee.estimate._check_medrecord(medrecord=self.medrecord) tee2 = ( @@ -274,7 +307,9 @@ def test_check_medrecord(self) -> None: .build() ) - with pytest.raises(ValueError, match="Outcome group not found in the MedRecord"): + with pytest.raises( + ValueError, match="Outcome group not found in the MedRecord" + ): tee2.estimate._check_medrecord(medrecord=self.medrecord) patient_group = "subjects" @@ -286,7 +321,10 @@ def test_check_medrecord(self) -> None: .build() ) - with pytest.raises(ValueError, match=f"Patient group {patient_group} not found in the MedRecord"): + with pytest.raises( + ValueError, + match=f"Patient group {patient_group} not found in the MedRecord", + ): tee3.estimate._check_medrecord(medrecord=self.medrecord) def test_find_treated_patients(self) -> None: @@ -551,7 +589,9 @@ def test_outcome_before_treatment(self) -> None: .build() ) - with pytest.raises(ValueError, match="No outcomes found in the MedRecord for group "): + with pytest.raises( + ValueError, match="No outcomes found in the MedRecord for group " + ): tee3._find_outcomes(medrecord=self.medrecord, treated_group=treated_group) def test_filter_controls(self) -> None: @@ -633,7 +673,9 @@ def test_find_controls(self) -> None: self.assertEqual(control_outcome_true, {"P1", "P4", "P7"}) self.assertEqual(control_outcome_false, {"P5", "P8", "P9"}) - with pytest.raises(ValueError, match="No patients found for control groups in this MedRecord."): + with pytest.raises( + ValueError, match="No patients found for control groups in this MedRecord." + ): tee._find_controls( self.medrecord, control_group=patients - treated_group, @@ -650,7 +692,9 @@ def test_find_controls(self) -> None: self.medrecord.add_group("Headache") - with pytest.raises(ValueError, match="No outcomes found in the MedRecord for group."): + with pytest.raises( + ValueError, match="No outcomes found in the MedRecord for group." + ): tee2._find_controls( self.medrecord, control_group=patients - treated_group, diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index 9cce65ad..133174ae 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -256,9 +256,7 @@ def _find_treated_patients(self, medrecord: MedRecord) -> Set[NodeIndex]: ) if not treated_group: msg = "No patients found for the treatment groups in this MedRecord." - raise ValueError( - msg - ) + raise ValueError(msg) return treated_group @@ -294,9 +292,7 @@ def _find_outcomes( outcomes = medrecord.nodes_in_group(self._outcomes_group) if not outcomes: msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" - raise ValueError( - msg - ) + raise ValueError(msg) for outcome in outcomes: nodes_to_check = set( @@ -463,9 +459,7 @@ def _find_controls( outcomes = medrecord.nodes_in_group(self._outcomes_group) if not outcomes: msg = f"No outcomes found in the MedRecord for group {self._outcomes_group}" - raise ValueError( - msg - ) + raise ValueError(msg) # Finding the patients that had the outcome in the control group for outcome in outcomes: From 4ba6d056c06a3f966a4255f7c5d2ad60c67d4234 Mon Sep 17 00:00:00 2001 From: FloLimebit Date: Fri, 11 Oct 2024 10:12:56 -0600 Subject: [PATCH 6/7] add_edges_to_group fix --- medmodels/medrecord/medrecord.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index c7640516..8a28bfcd 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -799,7 +799,7 @@ def add_edges( if not self.contains_group(group): self.add_group(group) - self.add_edge_to_group(group, edge_indices) + self.add_edges_to_group(group, edge_indices) return edge_indices From 56eb38a928de47f940e413aa150bc7324801156d Mon Sep 17 00:00:00 2001 From: Jakob Kraus Date: Tue, 15 Oct 2024 12:14:28 +0200 Subject: [PATCH 7/7] pr comments --- docs/developer_guide/docstrings.md | 7 +------ docs/developer_guide/example_docstrings.py | 24 ++++++++++++---------- medmodels/medrecord/medrecord.py | 8 -------- 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/docs/developer_guide/docstrings.md b/docs/developer_guide/docstrings.md index 36215c1b..2874703d 100644 --- a/docs/developer_guide/docstrings.md +++ b/docs/developer_guide/docstrings.md @@ -227,15 +227,13 @@ When writing type definitions in argument docstrings, avoid placing line breaks If a type definition exceeds the maximum line length limit, deactivate the line length rule for that doctstring using [`# noqa: E501`](https://docs.astral.sh/ruff/rules/line-too-long/#error-suppression) after the closing `"""`. This keeps the type annotation intact while preventing the linter from flagging the long line as an error. ::: - :::{admonition} No Boolean Argument Types :type: note -Avoid using booleans as function arguments due to the "boolean trap." Booleans reduce code clarity, as `True` or `False` doesn't explain meaning. They also limit future flexibility — if more than two options are needed, changing the argument type can cause breaking changes. Instead, use an enum or a more descriptive type for better clarity and flexibility. +Avoid using booleans as function arguments due to the "boolean trap". Booleans reduce code clarity, as `True` or `False` doesn't explain meaning. They also limit future flexibility — if more than two options are needed, changing the argument type can cause breaking changes. Instead, use an enum or a more descriptive type for better clarity and flexibility. For more information, you can refer to [Adam Johnson's article](https://adamj.eu/tech/2021/07/10/python-type-hints-how-to-avoid-the-boolean-trap/) which discusses this in detail and provides examples of how to avoid the boolean trap. ::: - ### Return Types Document return types under the `Returns` section. Each return type should include the type and a brief description. @@ -256,7 +254,6 @@ def example_function(param1, param2): Don't add a `Returns` section when the function has no return value (`def fun() -> None`) ::: - ### Examples Provide examples under the `Examples` section by using Sphinx's `.. code-block::` directive within the docstrings. @@ -287,8 +284,6 @@ The second block shows the return value when executing the code. The output valu ``` - - Full Example: ```python diff --git a/docs/developer_guide/example_docstrings.py b/docs/developer_guide/example_docstrings.py index b8e740c6..8c0861f1 100644 --- a/docs/developer_guide/example_docstrings.py +++ b/docs/developer_guide/example_docstrings.py @@ -1,15 +1,17 @@ +"""Example module with docstrings for the developer guide.""" + from __future__ import annotations -from typing import Any, Dict, Iterator +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union def example_function_args( param1: int, - param2: str | int, - optional_param: list[str] | None = None, - *args: float | str, + param2: Union[str, int], + optional_param: Optional[List[str]] = None, + *args: Union[float, str], **kwargs: Dict[str, Any], -) -> tuple[bool, list[str]]: +) -> Tuple[bool, List[str]]: """Example function with PEP 484 type annotations and PEP 563 future annotations. This function shows how to define and document typing for different kinds of @@ -17,18 +19,18 @@ def example_function_args( Args: param1 (int): A required integer parameter. - param2 (str | int): A parameter that can be either a string or an integer. - optional_param (list[str] | None, optional): An optional parameter that accepts - a list of strings. Defaults to None if not provided. - *args (float | str): Variable length argument list that accepts floats or + param2 (Union[str, int]): A parameter that can be either a string or an integer. + optional_param (Optional[List[str]], optional): An optional parameter that + accepts a list of strings. Defaults to None if not provided. + *args (Union[float, str]): Variable length argument list that accepts floats or strings. **kwargs (Dict[str, Any]): Arbitrary keyword arguments as a dictionary of string keys and values of any type. Returns: - tuple[bool, list[str]]: A tuple containing: + Tuple[bool, List[str]]: A tuple containing: - bool: Always True for this example. - - list[str]: A list with a single string describing the received arguments. + - List[str]: A list with a single string describing the received arguments. """ result = ( f"Received: param1={param1}, param2={param2}, optional_param={optional_param}, " diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index 8a28bfcd..72205d96 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -793,14 +793,6 @@ def add_edges( self.add_edges_to_group(group, edge_indices) - if group is None: - return edge_indices - - if not self.contains_group(group): - self.add_group(group) - - self.add_edges_to_group(group, edge_indices) - return edge_indices def add_edges_pandas(